R/tune_deconv.R

Defines functions plot_biasvar plot_tune best_nsubclass tune_stats summary.tune_deconv tune_dec tune_deconv

Documented in plot_biasvar plot_tune summary.tune_deconv tune_deconv

#' Tune deconvolution parameters
#' 
#' Performs an exhaustive grid search over a tuning grid of cell marker and
#' deconvolution parameters for either [updateMarkers()] (e.g. `expfilter` or
#' `nsubclass`) or [deconvolute()] (e.g. `comp_amount`).
#' 
#' @param mk cellMarkers class object
#' @param test matrix of bulk RNA-Seq to be deconvoluted. Passed to
#'   [deconvolute()].
#' @param samples matrix of cell amounts with subclasses in columns and samples
#'   in rows. Note that if this has been generated by [simulate_bulk()], using a
#'   value of `times` other than 1, then it is important that this is adjusted
#'   for here.
#' @param grid Named list of vectors for the tuning grid similar to
#'   [expand.grid()]. Names represent the parameter to be tuned which must be an
#'   argument in either [updateMarkers()] or [deconvolute()]. The elements of
#'   each vector are the values to be tuned for each parameter.
#' @param output Character value, either `"output"` or `"percent"`, specifying
#'   which element of the subclass results returned by [deconvolute()] is used.
#'   This deconvolution result is compared against the actual sample cell
#'   numbers in `samples`, using [metric_set()].
#' @param metric Specifies tuning metric to choose optimal tune: either
#'  "RMSE", "Rsq", "pearson" (Pearson r^2) or "resvar" (residual variance of
#'  bulk gene expression). "Rsq" and "pearson" are maximised; "RMSE" and
#'  "resvar" are minimised.
#' @param method Either "top" or "overall". Determines how best parameter values
#'   are chosen. With "top" the single top configuration is chosen. With
#'   "overall", the average effect of varying each parameter is calculated using
#'   the mean of `metric` across all variations of other parameters. This can
#'   give a more stable choice of final tuning.
#' @param verbose Logical whether to show progress.
#' @param cores Number of cores for parallelisation via [parallel::mclapply()].
#'   Parallelisation is not available on windows.
#' @param ... Optional arguments passed to [deconvolute()] to control fixed
#'   settings.
#' @returns Dataframe with class `'tune_deconv'` whose columns include: the
#'   parameters being tuned via `grid`, cell subclass and R squared, RMSE,
#'   pearson r^2, residual gene expression variance, prediction bias and
#'   variance, and kappa (condition number).
#' @details
#' Tuning plots on the resulting object can be visualised using [plot_tune()].
#' If `method` is set to "overall", this corresponds to setting 
#' `group = NULL` in [plot_tune()].
#'
#' Once the results output has been generated, arguments such as `metric` or
#' `method` can be changed to see different best tunes using `summary()` (see
#' [summary.tune_deconv()]).
#' 
#' `test` and `samples` matrices can be generated by [simulate_bulk()] and
#' [generate_samples()] based on the original scRNA-Seq count dataset.
#' 
#' @seealso [plot_tune()] [summary.tune_deconv()]
#' @importFrom stats aggregate
#' @export
tune_deconv <- function(mk, test, samples, grid,
                        output = "output",
                        metric = "RMSE",
                        method = "top",
                        verbose = TRUE, cores = 1, ...) {
  method <- match.arg(method, c("top", "overall"))
  # "pearson" partially matches "pearson.rsq"; "resvar" is always a column of
  # the results, so it can be tuned on as documented
  metric <- match.arg(metric, c("RMSE", "Rsq", "pearson.rsq", "resvar"))
  # R-squared metrics are maximised, error metrics (RMSE, resvar) minimised.
  # Used by both `method`s so that "overall" does not pick the worst RMSE.
  metFUN <- if (metric %in% c("Rsq", "pearson.rsq")) which.max else which.min
  if (!inherits(mk, "cellMarkers")) stop("`mk` is not a cellMarkers objects")
  if (ncol(test) != nrow(samples)) stop("incompatible test and samples")
  if (!identical(colnames(mk$genemeans), colnames(samples)))
    stop("incompatible subclasses between mk and samples")
  
  # `verbose` and `cores` are controlled by tune_deconv() itself
  grid[c("verbose", "cores")] <- NULL
  params <- names(grid)
  arg_set1 <- names(formals(updateMarkers))
  arg_set2 <- names(formals(deconvolute))
  if (any(!params %in% c(arg_set1, arg_set2)))
    stop("unknown tuning parameter in `grid`")
  dots <- list(...)
  if (any(params %in% names(dots))) stop("argument in `grid` also passed in `...`")
  # outer grid: updateMarkers() arguments; inner grid: deconvolute() arguments
  w1 <- which(params %in% arg_set1)
  w2 <- which(params %in% arg_set2)
  grid2 <- if (length(w2) > 0) expand.grid(grid[w2], stringsAsFactors = FALSE) else NULL
  if (verbose) {
    message("Tuning parameters: ", paste(params, collapse = ", "))
  }
  
  # disable group analysis
  mk$group_geneset <- mk$group_angle <- mk$groupmeans <- NULL
  mk$groupmeans_filtered <- NULL
  
  if (length(w1) > 0) {
    # keep character values as character (as for grid2) so they reach
    # updateMarkers() unchanged rather than as factors
    grid1 <- expand.grid(grid[w1], stringsAsFactors = FALSE)
    res <- pmclapply(seq_len(nrow(grid1)), function(i) {
      args <- list(object = mk, verbose = NA)
      grid1_row <- grid1[i, , drop = FALSE]
      args <- c(args, grid1_row)
      mk_update <- do.call("updateMarkers", args) |> suppressMessages()
      df2 <- tune_dec(mk_update, test, samples, grid2, output, ...)
      data.frame(grid1_row, df2, row.names = NULL)
    }, progress = verbose, mc.cores = cores, mc.preschedule = FALSE)
    res <- do.call(rbind, res)
  } else {
    # null grid1
    if (is.null(grid2)) stop("No parameters to tune")
    res <- tune_dec(mk, test, samples, grid2, output, verbose, cores, ...)
  }
  res$subclass <- factor(res$subclass, levels = names(mk$cell_table))
  
  if (method == "top") {
    # single best configuration by metric averaged across subclasses
    mres <- aggregate(res[, metric], by = res[, params, drop = FALSE],
                      FUN = mean, na.rm = TRUE)
    w <- metFUN(mres$x)
    colnames(mres)[which(colnames(mres) == "x")] <- paste0("mean.", metric)
    best_tune <- mres[w, ]
  } else {
    # best value of each parameter separately, averaging over the rest of grid
    best_tune <- lapply(params, function(i) {
      mres <- aggregate(res[, metric], by = res[, i, drop = FALSE], FUN = mean,
                        na.rm = TRUE)
      mres[metFUN(mres$x), i]
    })
    best_tune <- data.frame(best_tune)
    colnames(best_tune) <- params
  }
  if (verbose) {
    message("Best tune:")
    print(best_tune, row.names = FALSE, digits = max(3, getOption("digits")-3),
          print.gap = 2L)
  }
  
  attr(res, "tune") <- best_tune
  attr(res, "metric") <- metric
  attr(res, "method") <- method
  class(res) <- c("tune_deconv", class(res))
  res
}


# tune inner grid of arguments for deconvolute()
#
# Runs deconvolute() on `test` using `mk`: once with the fixed settings in
# `...` when `grid2` is NULL, otherwise once per row of `grid2` (via
# pmclapply(), with `progress` and `cores` controlling that loop). Each fit's
# `fit$subclass[[output]]` is scored against `samples` with metric_set(), giving
# one row per subclass with columns: [grid2 parameters], subclass, ngene,
# resvar, kappa, then the metric_set() columns. Both branches return the same
# metric columns so results can always be summarised by tune_stats().
tune_dec <- function(mk, test, samples, grid2, output, progress = FALSE,
                     cores = 1L, ...) {
  # score one deconvolute() fit, optionally prefixed by its grid row
  score_fit <- function(fit, grid_row = NULL) {
    out <- metric_set(samples, fit$subclass[[output]])
    # genes actually used: marker geneset minus genes removed during the fit
    ngene <- length(fit$mk$geneset) - length(fit$subclass$removed)
    df <- data.frame(subclass = rownames(out), ngene,
                     resvar = mean(fit$subclass$resvar),
                     kappa = kappa(fit$subclass$spillover), row.names = NULL)
    if (!is.null(grid_row)) df <- data.frame(grid_row, df, row.names = NULL)
    cbind(df, out)
  }
  
  if (is.null(grid2)) {
    fit <- deconvolute(mk, test, verbose = FALSE, ...) |>
      suppressMessages()
    return(score_fit(fit))
  }
  # loop grid2
  dots <- list(...)
  args <- list(mk = mk, test = test, verbose = FALSE)
  res <- pmclapply(seq_len(nrow(grid2)), function(i) {
    grid2_row <- grid2[i, , drop = FALSE]
    args <- c(args, grid2_row)
    if (length(dots)) args[names(dots)] <- dots
    fit <- do.call("deconvolute", args) |> suppressMessages()
    score_fit(fit, grid2_row)
  }, progress = progress, mc.cores = cores, mc.preschedule = FALSE)
  do.call(rbind, res)
}


#' Summarising deconvolution tuning
#' 
#' `summary` method for class `'tune_deconv'`.
#' 
#' @param object dataframe of class `'tune_deconv'`.
#' @param metric Specifies tuning metric to choose optimal tune: one of
#'   "pearson.rsq" (may be abbreviated to "pearson"), "Rsq", "RMSE", "bias",
#'   "var", "resvar" (residual variance of bulk gene expression) or "kappa".
#'   "Rsq" and "pearson.rsq" are maximised; all other metrics are minimised.
#' @param method Either "top" or "overall". Determines how best parameter values
#'   are chosen. With "top" the single top configuration is chosen. With
#'   "overall", the average effect of varying each parameter is calculated using
#'   the mean of `metric` across all variations of other parameters. This can
#'   give a more stable choice of final tuning.
#' @param ... further arguments passed to other methods.
#' @returns If `method = "top"` prints the row representing the best tuning of
#'   parameters (best value of `metric`, averaged across subclasses). For method
#'   = "overall", the average effect of varying each parameter is calculated by
#'   the mean of `metric` across the rest of the grid and the best value for
#'   each parameter is printed. If `metric` and `method` match those used in
#'   [tune_deconv()], the stored best tune is printed. Invisibly returns a
#'   dataframe of metric values averaged over subclasses for each parameter
#'   combination (columns `mean.pearson.rsq`, `mean.Rsq`, `mean.RMSE`, `bias`,
#'   `var`, `resvar`, `kappa`).
#' @export
summary.tune_deconv <- function(object,
                                metric = attr(object, "metric"),
                                method = attr(object, "method"),
                                ...) {
  method <- match.arg(method, c("top", "overall"))
  metric <- match.arg(metric, met_params)
  metFUN <- if (metric %in% c("Rsq", "pearson.rsq")) which.max else which.min
  # tune_stats() only adds the "mean." prefix to pearson.rsq, Rsq and RMSE
  metcol <- if (metric %in% c("Rsq", "pearson.rsq", "RMSE")) {
    paste0("mean.", metric)
  } else metric
  
  params <- colnames(object)
  params <- params[!params %in% c("subclass", met_params)]
  mres <- tune_stats(object, params)
  w <- metFUN(mres[, metcol])
  
  # identical() copes with missing attributes (NULL) without giving if (NA)
  if (identical(method, attr(object, "method")) &&
      identical(metric, attr(object, "metric"))) {
    # same settings as the original tuning run: reuse the stored best tune
    best_tune <- attr(object, "tune")
  } else if (method == "top") {
    best_tune <- mres[w, ]
  } else {
    # overall mean
    best_tune <- lapply(params, function(i) {
      mres <- aggregate(object[, metric], by = object[, i, drop = FALSE],
                        FUN = mean, na.rm = TRUE)
      w <- metFUN(mres$x)
      mres[w, i]
    })
    best_tune <- data.frame(best_tune)
    colnames(best_tune) <- params
  }
  
  message("Best tune:")
  print(best_tune, row.names = FALSE, digits = max(3, getOption("digits") -3),
        print.gap = 2L)
  invisible(mres)
}


# Names of the metric columns in a 'tune_deconv' results data frame; every
# other non-subclass column is treated as a tuning parameter. Order matters:
# the first three are the metrics that tune_stats() renames with a "mean."
# prefix after averaging.
met_params <- c(
  "pearson.rsq", "Rsq", "RMSE",
  "bias", "var", "resvar", "kappa"
)

# Average every metric column listed in `met_params` over all rows of `object`
# sharing the same combination of values in the `params` columns. Of the
# averaged columns, pearson.rsq, Rsq and RMSE (met_params[1:3]) are renamed with
# a "mean." prefix; the remaining metric columns keep their names.
tune_stats <- function(object, params) {
  averaged <- aggregate(object[, met_params],
                        by = object[, params, drop = FALSE],
                        FUN = mean, na.rm = TRUE)
  cols <- colnames(averaged)
  headline <- cols %in% met_params[1:3]
  cols[headline] <- paste0("mean.", cols[headline])
  colnames(averaged) <- cols
  averaged
}


# Best `nsubclass` value for each cell subclass.
#
# Restricts `object` (a 'tune_deconv' result) to the rows whose values match
# the stored best tune (attribute "tune") for every tuned column other than
# `nsubclass` and the "mean.<metric>" summary column. Then, separately for each
# level of `object$subclass`, picks the `nsubclass` value with the best
# `metric` ("Rsq"/"pearson.rsq" maximised, any other metric minimised).
# Returns a vector named by subclass.
best_nsubclass <- function(object, metric = attr(object, "metric")) {
  if (!"nsubclass" %in% colnames(object)) stop("nsubclass not tuned")
  tune <- attr(object, "tune")
  ignore <- c("nsubclass", paste0("mean.", metric))
  other_cols <- colnames(tune)[!colnames(tune) %in% ignore]
  if (length(other_cols) > 0) {
    matches <- lapply(other_cols, function(col) object[, col] == tune[[col]])
    # keep rows matching on every other tuned parameter
    keep <- rowSums(do.call(cbind, matches)) == length(other_cols)
    object <- object[keep, ]
  }
  pick <- if (metric %in% c("Rsq", "pearson.rsq")) which.max else which.min
  subclasses <- levels(object$subclass)
  best <- lapply(subclasses, function(s) {
    rows <- object[object$subclass == s, ]
    rows$nsubclass[pick(rows[, metric])]
  })
  names(best) <- subclasses
  unlist(best)
}


#' Plot tuning curves
#' 
#' Produces a ggplot2 plot of R-squared/RMSE values generated by
#' [tune_deconv()].
#' 
#' @param result Dataframe of tuning results generated by [tune_deconv()].
#' @param group Character value specifying column in `result` to be grouped by
#'   colour; or `NULL` to average R-squared/RMSE values across the grid and show
#'   the generalised mean effect of varying the parameter specified by `xvar`.
#' @param xvar Character value specifying column in `result` to vary along the x
#'   axis.
#' @param fix Optional named list specifying parameters to be fixed at specific
#'   values. Each element must be a single value present in the tuning grid.
#' @param metric Specifies tuning metric: either "RMSE", "Rsq", "pearson" or
#'   "resvar" (residual variance of bulk gene expression).
#' @param title Character value for the plot title.
#' @param show_legend Logical whether to show the legend when `group` is set to
#'   `"subclass"`. By default the legend is hidden if there are 25 or more
#'   subclasses.
#' @param errorbars Logical whether to show error bars (mean +/- standard
#'   error).
#' @param show_points Logical whether to overlay points.
#' @returns ggplot2 plot object.
#' @details
#' If `group` is set to `"subclass"`, then the tuning parameter specified by
#' `xvar` is varied on the x axis. Any other tuning parameters (i.e. if 2 or
#' more have been tuned) are fixed to their best tuned values.
#' 
#' If `group` is set to a different column than `"subclass"`, then the mean
#' R-squared/RMSE values in `result` are averaged over subclasses. This makes it
#' easier to compare the overall effect (mean R-squared/RMSE) of 2 tuned
#' parameters which are specified by `xvar` and `group`. Any remaining
#' parameters not shown are fixed to their best tuned values.
#' 
#' If `group` is `NULL`, the tuning parameter specified by `xvar` is varied on
#' the x axis and R-squared/RMSE values are averaged over the whole grid to give
#' the generalised mean effect of varying the `xvar` parameter.
#' @importFrom dplyr near
#' @importFrom ggplot2 geom_line ggtitle mean_se stat_summary theme_bw labs
#' @export
plot_tune <- function(result, group = "subclass", xvar = colnames(result)[1],
                      fix = NULL,
                      metric = attr(result, "metric"),
                      title = NULL,
                      show_legend = nlevels(result$subclass) < 25,
                      errorbars = TRUE,
                      show_points = TRUE) {
  # logical row index of `dat` whose `cols` equal the values in 1-row `tune`.
  # Numeric columns are compared with a tolerance; other types (character,
  # factor, logical) exactly, since near() cannot compare factors.
  match_tune <- function(dat, cols, tune) {
    hits <- lapply(cols, function(i) {
      if (is.numeric(dat[, i]) && is.numeric(tune[, i])) {
        near(dat[, i], tune[, i])
      } else dat[, i] == tune[, i]
    })
    rowSums(do.call(cbind, hits)) == length(cols)
  }
  # "param = value, ..." label; as.character() shows factor labels, not codes
  tune_label <- function(cols, tune) {
    vals <- vapply(cols, function(i) as.character(tune[[i]]), character(1))
    paste(cols, vals, sep = " = ", collapse = ", ")
  }
  
  params <- colnames(result)
  params <- params[!params %in% c("subclass", met_params)]
  if (!xvar %in% params) stop("incorrect `xvar`")
  metric <- match.arg(metric, met_params)
  
  if (is.null(group)) {
    xdiff <- diff(range(result[, xvar], na.rm = TRUE))
    
    p <- ggplot(result, aes(x = .data[[xvar]], y = .data[[metric]])) +
      stat_summary(fun = mean, geom = "line", col = "limegreen") +
      (if (errorbars) {
        stat_summary(fun.data = mean_se, geom = "errorbar", col = "black",
                     width = 0.02 * xdiff)}) +
      # mean points are controlled by `show_points`
      (if (show_points) {
        stat_summary(fun = mean, geom = "point", col = "black")}) +
      ggtitle(title) +
      theme_bw() +
      theme(plot.title = element_text(size = 9),
            axis.text = element_text(colour = "black"),
            axis.ticks = element_line(colour = "black"))
    return(p)
  }
  if (!group %in% colnames(result)) stop("incorrect `group`")
  by_params <- c(group, xvar)
  fix_params <- params[!params %in% c(by_params, "ngene")]
  # nsubclass determines ngene, so it cannot be fixed when ngene is plotted
  if ("ngene" %in% by_params) fix_params <- fix_params[fix_params != "nsubclass"]
  
  mres <- tune_stats(result, params)
  metcol <- if (metric %in% c("Rsq", "pearson.rsq", "RMSE")) {
    paste0("mean.", metric)
  } else metric
  metFUN <- if (metric %in% c("Rsq", "pearson.rsq")) which.max else which.min
  w <- metFUN(mres[, metcol])
  best_tune <- mres[w, ]
  
  # custom fix params
  if (!is.null(fix)) {
    if (is.null(names(fix)) || !all(names(fix) %in% colnames(best_tune)))
      stop("unable to fix parameter")
    for (i in seq_along(fix)) {
      val <- fix[[i]]
      if (length(val) != 1 || !val %in% unique(mres[, names(fix)[i]])) {
        stop("unable to fix parameter level")
      }
      best_tune[[names(fix)[i]]] <- val
    }
  }
  
  if (group == "subclass") {
    # usual plot
    if (length(fix_params)) {
      # 2 or more params tuned, fix using best_tune
      lab <- tune_label(fix_params, best_tune)
      message("Fix ", lab)
      if (is.null(title)) title <- lab
      result <- result[match_tune(result, fix_params, best_tune), ]
    }
    xdiff <- diff(range(result[, xvar], na.rm = TRUE))
    
    ggplot(result, aes(x = .data[[xvar]], y = .data[[metric]],
                       color = .data[[group]])) +
      geom_line() +
      (if (show_points) geom_point()) +
      (if (errorbars) {
        stat_summary(fun.data = mean_se, geom = "errorbar", col = "black",
                     width = 0.02 * xdiff)}) +
      stat_summary(fun = mean, geom = "point", col = "black") +
      labs(color = "") +
      ggtitle(title) +
      theme_bw() +
      theme(plot.title = element_text(size = 9),
            axis.text = element_text(colour = "black"),
            axis.ticks = element_line(colour = "black"),
            legend.key.size = unit(0.8, 'lines'),
            legend.spacing.y = unit(0, 'lines')) +
      (if (!show_legend) theme(legend.position = "none"))
  } else {
    # mean metric over subclasses
    if (length(fix_params)) {
      # 3 or more params tuned, fix using best_tune
      lab <- tune_label(fix_params, best_tune)
      message("Fix ", lab)
      if (is.null(title)) title <- lab
      mres <- mres[match_tune(mres, fix_params, best_tune), ]
    }
    mres[, group] <- factor(mres[, group])
    ggplot(mres, aes(x = .data[[xvar]], y = .data[[metcol]],
                       color = .data[[group]])) +
      geom_line() +
      (if (show_points) geom_point()) +
      ggtitle(title) +
      theme_bw() +
      theme(plot.title = element_text(size = 9),
            axis.text = element_text(colour = "black"),
            axis.ticks = element_line(colour = "black"),
            legend.key.size = unit(0.8, 'lines'),
            legend.spacing.y = unit(0, 'lines'))
  }
}


#' Plot bias-variance decomposition
#' 
#' Produces a ggplot2 plot of total error (MSE) against bias and variance based
#' on results generated by [tune_deconv()]. Primarily useful for examining the
#' effect of altering compensation or lambda.
#' 
#' @param result Dataframe of tuning results generated by [tune_deconv()]. Only
#'   a single tuned parameter (besides `ngene`) is supported.
#' @param subclass Character value specifying which cell subclass to plot, or
#'   numeric selecting the subclass by factor level.
#' @param xvar Character value specifying column in `result` to vary along the x
#'   axis.
#' @returns ggplot2 line plot of MSE (`RMSE^2`), squared bias (`bias^2`) and
#'   variance (`var`) against `xvar`.
#' @importFrom ggplot2 scale_colour_discrete
#' @export
plot_biasvar <- function(result, subclass = 1,
                         xvar = colnames(result)[1]) {
  if (!requireNamespace("tidyr", quietly = TRUE))
    stop("Package 'tidyr' is not installed", call. = FALSE)
  params <- colnames(result)
  params <- params[!params %in% c("subclass", "ngene", met_params)]
  if (length(params) > 1) stop("multiple parameters")
  if (length(xvar) != 1 || !xvar %in% colnames(result))
    stop("incorrect `xvar`")
  
  if (is.numeric(subclass)) subclass <- levels(result$subclass)[subclass]
  # fail early rather than silently drawing an empty plot
  if (length(subclass) != 1 || !subclass %in% result$subclass)
    stop("unknown `subclass`")
  res <- result[result$subclass == subclass, ]
  res$MSE <- res$RMSE^2
  res$`bias^2` <- res$bias^2
  key <- c("MSE", "bias^2", "var")
  sres <- res[, c(xvar, key)]
  # all_of() selects by an external character vector without the tidyselect
  # deprecation warning for bare external vectors
  dat <- tidyr::pivot_longer(sres, tidyr::all_of(key))
  dat$name <- factor(dat$name, levels = key)
  
  ggplot(dat, aes(x = .data[[xvar]], y = .data[["value"]],
                  colour = .data[["name"]],
                  linetype = .data[["name"]])) +
    geom_line() +
    scale_colour_discrete(type = c("black", "red", "blue")) +
    ylab("MSE / bias / variance") +
    labs(color = "", linetype = "") +
    ggtitle(subclass) +
    theme_classic() +
    theme(plot.title = element_text(size = 9),
          axis.text = element_text(colour = "black"))
}

Try the cellGeometry package in your browser

Any scripts or data that you put into this service are public.

cellGeometry documentation built on April 20, 2026, 1:06 a.m.