R/velocity.R

Defines functions add_velocity get_velocity check_scvelo

Documented in add_velocity get_velocity

#' Add velocity to a dynwrap dataset
#'
#' Computes (or accepts) an RNA velocity estimate and stores it in `dataset`:
#' the extrapolated expression is stored as `dataset$expression_future`, and
#' all remaining elements of the velocity object are stored as
#' `dataset$velocity`.
#'
#' @inheritParams dynwrap::add_trajectory
#' @inheritParams get_velocity
#' @param velocity The velocity object as generated by [get_velocity()].
#'   Must contain a non-`NULL` `expression_future` element.
#'
#' @return `dataset` with `expression_future` and `velocity` added and the
#'   class `"wrapper_with_velocity"` added.
#'
#' @importFrom dynutils add_class
#'
#' @export
add_velocity <- function(
  dataset,
  spliced = dataset$expression,
  unspliced = dataset$expression_unspliced,
  mode = c("stochastic", "deterministic", "dynamical"),
  n_neighbors = 20L,
  velocity = get_velocity(
    spliced = spliced,
    unspliced = unspliced,
    mode = mode,
    n_neighbors = n_neighbors
  )
) {
  # `mode` is matched before `velocity` is first touched: the default
  # `velocity` promise reads `mode`, so it must already be a single value.
  mode <- match.arg(mode)
  assert_that(!is.null(velocity$expression_future))

  # The extrapolated expression lives at the top level of the dataset;
  # everything else (transition matrix, scvelo object, ...) under `velocity`.
  dataset$expression_future <- velocity$expression_future
  velocity$expression_future <- NULL
  dataset$velocity <- velocity

  # add_class() returns a modified copy, so its result must be returned;
  # discarding it would silently drop the class.
  dynutils::add_class(dataset, "wrapper_with_velocity")
}


#' Calculate velocity
#'
#' Estimates RNA velocity with scvelo, uses it to extrapolate expression one
#' step ahead, and computes the scvelo transition matrix between cells.
#'
#' @param spliced Spliced expression matrix, with cells (observations) in rows
#'   and genes (variables) in columns.
#' @param unspliced Unspliced expression matrix, with the same dimensions and
#'   dimnames as `spliced`.
#' @param var_names Names of variables/genes to use for the fitting. Can be
#'   `"velocity_genes"`, `"all"`, or a set of gene names. Only used when
#'   `mode = "dynamical"`.
#' @param mode Which model of transcriptional dynamics to use for the
#'   estimation: `"stochastic"`, `"deterministic"` or `"dynamical"`.
#' @param n_neighbors Number of neighbors to use when computing moments.
#' @param layer Expression to extrapolate from: `"spliced"` (the `spliced`
#'   input) or `"imputed"` (the `Ms` layer computed by `scvelo.pp.moments`).
#'
#' @return A list with elements
#'   * `expression_future`: a `dgCMatrix` holding the chosen expression plus
#'     the velocity (missing velocities are treated as 0);
#'   * `transition_matrix`: the scvelo transition matrix, with cell names as
#'     row and column names;
#'   * `scvelo`: the underlying AnnData Python object.
#'
#' @importFrom methods as
#'
#' @export
get_velocity <- function(
  spliced,
  unspliced,
  mode = c("stochastic", "deterministic", "dynamical"),
  n_neighbors = 20L,
  var_names = "velocity_genes",
  layer = "spliced"
) {
  # check inputs
  mode <- match.arg(mode)
  # Reject unknown layers instead of silently falling back to "spliced".
  layer <- match.arg(layer, c("spliced", "imputed"))
  assert_that(
    all(dim(spliced) == dim(unspliced)),
    all(rownames(spliced) == rownames(unspliced)),
    all(colnames(spliced) == colnames(unspliced))
  )

  # create anndata object: rows of `spliced` become observations (obs),
  # columns become variables (var)
  velocity <- anndata$AnnData(spliced)
  velocity$var_names <- colnames(spliced)
  velocity$obs_names <- rownames(spliced)

  py_assign(velocity$layers, "spliced", spliced)
  py_assign(velocity$layers, "unspliced", unspliced)

  # calculate velocity
  # Output is deliberately not wrapped in py_capture_output(): that crashes
  # during testing, see https://github.com/rstudio/reticulate/issues/386
  scvelo$pp$moments(velocity, n_neighbors = n_neighbors)

  if (mode == "dynamical") {
    # A deterministic fit is run before recover_dynamics(), presumably so that
    # the `velocity_genes` selected by the default `var_names` exist.
    scvelo$tl$velocity(velocity, mode = "deterministic")
    scvelo$tl$velocity_graph(velocity)

    scvelo$tl$recover_dynamics(velocity, var_names = var_names)
  }
  scvelo$tl$velocity(velocity, mode = mode)
  scvelo$tl$velocity_graph(velocity)

  # extrapolate expression along the velocity
  velocity_vector <- velocity$layers[["velocity"]]
  velocity_vector[is.na(velocity_vector)] <- 0

  if (layer == "imputed") {
    imputed <- velocity$layers[["Ms"]]
    expression_future <- imputed + velocity_vector
    dimnames(expression_future) <- dimnames(spliced)
  } else {
    expression_future <- spliced + velocity_vector
  }

  expression_future <- as(expression_future, "dgCMatrix")

  # get transition matrix
  # The AnnData object is handed to Python under a dedicated name (instead of
  # a generic `x` that could clobber a user's variable in the shared
  # __main__ module) and deleted afterwards so Python does not keep it alive.
  py$scvelo_velocity_adata <- velocity
  py_run_string("import scvelo")
  transition_matrix <- py_to_r(py_eval(
    "scvelo.tl.transition_matrix(scvelo_velocity_adata).tocsc()"
  ))
  py_run_string("del scvelo_velocity_adata")
  colnames(transition_matrix) <- rownames(spliced)
  rownames(transition_matrix) <- rownames(spliced)

  # return velocity object
  list(
    expression_future = expression_future,
    transition_matrix = transition_matrix,
    scvelo = velocity
  )
}

# Check whether a stored scvelo (AnnData) Python object is still usable.
#
# reticulate Python objects wrap an external pointer; after the R object has
# been saved to and restored from an rds file that pointer is null and the
# Python object is gone.
#
# @param scvelo The `scvelo` element of a velocity object (see get_velocity()),
#   or NULL.
# @return TRUE if `scvelo` is non-NULL and still points to a live Python
#   object, FALSE otherwise.
check_scvelo <- function(scvelo) {
  if (is.null(scvelo)) {
    return(FALSE)
  }
  # Use reticulate's documented null-pointer test rather than comparing the
  # printed representation against "<pointer: 0x0>", which is fragile and
  # breaks `if` when the comparison is not length one.
  !reticulate::py_is_null_xptr(scvelo)
}
dynverse/scvelo documentation built on April 9, 2020, 3:42 a.m.