R/kmer_logistic_regression.R

Defines functions plot_wrapper inverse_logit plot_trinuc_logistic_regression plot_kmer_logistic_regression kmer_logistic_regression

Documented in inverse_logit kmer_logistic_regression plot_kmer_logistic_regression plot_trinuc_logistic_regression

#' Fit regularized multinomial regression model
#'
#' Fit a multinomial regression model on encoded kmer-data. The input is
#' encoded with `encode_seqs(x, kmers)` and the fitting routine is applied to
#' the encoded data through `on_pyrs` (presumably once per pyrimidine; see
#' `on_pyrs` for the exact semantics).
#'
#' Columns of the encoded context matrix whose name is exactly three
#' characters long (the trinucleotides) are left unpenalized; every other
#' column gets an L1 (lasso, `alpha = 1`) penalty. The model is fitted without
#' an intercept and without standardization.
#'
#' @param x Dataset, including mutations and pyrimidine based sequence contexts
#' @param kmers String vector. Kmers to include when fitting signature
#' @param include_fit Bool. Include the fitted `glmnet` object in output?
#' @param include_encoded Bool. Include the input data encoded to binary matrix?
#' @param ... Additional arguments to `glmnet`
#' @return A list with element `res` (the coefficient arrays with the
#'   intercept row removed; the per-class coefficient matrices returned by
#'   `coef()` are stacked along the third dimension) and, depending on
#'   `include_encoded` and `include_fit`, the elements `enc` (encoded input)
#'   and `fit` (fitted `glmnet` objects).
#' @export
kmer_logistic_regression = function(x, kmers, include_fit = FALSE, include_encoded = TRUE, ...) {
  aux = function(z) {
    var_names = colnames(z$contexts)
    n_vars    = ncol(z$contexts)

    # Build the penalty per column instead of assuming that all trinucleotide
    # columns come first: 0 (unpenalized) for three-character names, 1 for
    # everything else. Without column names nothing can be identified as a
    # trinucleotide, so every column is penalized.
    if (is.null(var_names)) {
      penalty = rep(1, n_vars)
    } else {
      penalty = as.numeric(nchar(var_names) != 3)
    }

    glmnet_res = glmnet::glmnet(x                = z$contexts,
                                y                = z$mutations,
                                family           = 'multinomial',
                                intercept        = FALSE,
                                penalty.factor   = penalty,
                                trace.it         = 0,
                                alpha            = 1,
                                standardize      = FALSE,
                                ...)

    # One dense matrix per element returned by `coef()`, stacked along a
    # third dimension.
    coeff_list    = lapply(stats::coef(glmnet_res), as.matrix)
    glmnet_coeffs = abind::abind(coeff_list, along = 3)

    # First row is the intercept, which is forced to zero. `drop = FALSE`
    # keeps the array three-dimensional even when a dimension has length one
    # (e.g. a single lambda), which is the layout the plotting code indexes.
    glmnet_coeffs = glmnet_coeffs[-1, , , drop = FALSE]

    return(list(coeffs = glmnet_coeffs, fit = glmnet_res))
  }

  enc   = encode_seqs(x, kmers)
  x_res = on_pyrs(enc, aux)
  ret_list = list(
    res   = on_pyrs(x_res, function(x) x[['coeffs']]),
    enc   = enc,
    fit   = on_pyrs(x_res, function(x) x[['fit']])
  )

  # `res` is always returned; `enc` and `fit` only on request.
  return(ret_list[c(TRUE, include_encoded, include_fit)])
}


#' Plot results from a fitted `kmer_logistic_regression`
#'
#' Generate a plot to overview the different coefficients in the model.
#' Only kmers whose name is not exactly three characters long are shown, as a
#' tile plot of mutation class against kmer with the (reference corrected)
#' coefficient printed in each tile.
#' @param kmer_model The result from a `kmer_logistic_regression` call
#' @param plot_modifier A function that takes and returns a ggplot object
#' @return A list of ggplot objects
#' @export
plot_kmer_logistic_regression = function(kmer_model, plot_modifier = function(x) x) {
  # Note that `glmnet` will produce coeffs for the reference class
  # (NA; no mutation) as well as the other categories (C>A, C>G, etc...).
  # This behavior is (I believe) package dependent. It's not what you
  # would get from a usual binomial logistic regression.
  # In order to have NA as a reference point for the other coefficients,
  # the NA-coefficient is subtracted from all other coeffs below. After this
  # we can run a `inverse_logit` on the resulting coefficients to transform
  # values from log-odds-space to probability-space. I have tested this
  # manually to verify that we do get correct probabilities after
  # inverse-logit transformation. The values are off by a very small margin
  # that most likely is due to the parameter estimation itself and/or
  # rounding errors, not the transformation.

  # Kmer correction: subtract the reference class coefficients from every
  # non-zero coefficient of the remaining classes. `x` is a three-dimensional
  # array whose third dimension is named by class.
  kmer_correcter = function(x, ref_point = 'NA') {
    ref_column = dimnames(x)[[3]] %in% ref_point
    if (!any(ref_column)) {
      # Without the reference class the subtraction below would silently
      # fill the result with NA values.
      stop("Reference class '", ref_point,
           "' not found among the coefficient classes", call. = FALSE)
    }

    ref_vals = as.vector(x[, , which(ref_column)[1]])

    # `drop = FALSE` keeps all three dimensions even if one has length one.
    ret     = x[, , !ref_column, drop = FALSE]
    nonzero = ret != 0

    # The reference slice is repeated once per remaining class so that it
    # lines up element by element with `ret`. Coefficients that are exactly
    # zero are left untouched.
    shifted      = ret - rep(ref_vals, times = dim(ret)[3])
    ret[nonzero] = shifted[nonzero]
    ret
  }

  kmer_model_corrected = {
    ret = kmer_model
    ret$res = on_pyrs(kmer_model$res, kmer_correcter)
    ret
  }

  aes = ggplot2::aes
  `%>%` = tidyr::`%>%`

  # Draws the tile plot for one long-format coefficient table (columns
  # `kmer`, `mut_class` and `model_coeff`) and hands it to `plot_modifier`.
  plot_aux = function(dat) {
    ret = dat %>%
      dplyr::filter(nchar(kmer) != 3) %>%
      ggplot2::ggplot(aes(x = mut_class, y = kmer, fill = model_coeff)) +
      ggplot2::geom_tile() +
      ggplot2::geom_text(aes(label = round(model_coeff, 3)))

    plot_modifier(ret)
  }
  plot_wrapper(kmer_model_corrected, plot_aux)
}

#' Plot results from a fitted `kmer_logistic_regression`
#'
#' Generate a plot to overview the trinucleotide part of a kmer model. The
#' coefficients of all three-character kmers are passed through
#' `inverse_logit` and drawn as bars, with one facet per mutation class.
#' @param kmer_model The result from a `kmer_logistic_regression` call
#' @return A ggplot object
#' @export
plot_trinuc_logistic_regression = function(kmer_model) {
  # Builds the bar plot for one long-format coefficient table.
  draw_trinucs = function(coeff_tbl) {
    trinuc_tbl = dplyr::filter(coeff_tbl, nchar(kmer) == 3)
    trinuc_tbl = dplyr::mutate(trinuc_tbl, prob = inverse_logit(model_coeff))

    bar_mapping = ggplot2::aes(x = kmer, y = prob, fill = mut_class)
    x_labels    = ggplot2::element_text(angle = 90, vjust = .5)

    ggplot2::ggplot(trinuc_tbl, bar_mapping) +
      ggplot2::geom_col() +
      ggplot2::facet_grid(.~mut_class, scales = 'free_x') +
      ggplot2::theme(axis.text.x = x_labels)
  }

  # The trinucleotide coefficients are not regularized, so the choice of
  # regularization level is irrelevant here; the first plot is returned.
  per_level_plots = plot_wrapper(kmer_model, draw_trinucs)
  per_level_plots[[1]]
}

#' Inverse logit
#'
#' Transform log-odds to probabilities
#' @param x Numeric. Log-odds value
#' @return A probability value
inverse_logit = function(x) {
  # Written as 1 / (1 + exp(-x)) rather than exp(x) / (1 + exp(x)): the
  # latter evaluates to Inf / Inf = NaN once exp(x) overflows, whereas this
  # form saturates cleanly at 1 for large x and at 0 for very negative x.
  1 / (1 + exp(-x))
}

# Apply `plotter` to the model coefficients at three regularization levels.
#
# `kmer_model$res` is expected to hold (at least) two three-dimensional
# coefficient arrays whose second dimension indexes the regularization
# levels. For each of the levels 'min' (first column of both arrays), 'mid'
# (column `ceiling(ncol / 2)` of each array) and 'max' (last column of each
# array) the selected slices are bound together column-wise, reshaped to a
# long table with the columns `kmer`, `mut_class` and `model_coeff`, and
# handed to `plotter`. Returns a list with the elements 'min', 'mid', 'max'.
plot_wrapper = function(kmer_model, plotter) {
  coeff_arrays = kmer_model$res
  last_levels  = unlist(on_pyrs(coeff_arrays, ncol))
  mid_levels   = ceiling(last_levels / 2)

  # Long-format coefficient table for a single regularization level.
  subset_aux = function(reg_lvl = c("min", "mid", "max")) {
    rl = reg_lvl[1]
    if (rl == 'min') {
      j = c(1,1)
    } else if (rl == 'mid') {
      j = mid_levels
    } else if (rl == 'max') {
      j = last_levels
    } else {
      log_error("Internal error: Invalid argument to plot_kmer_logistic_regression->aux")
    }

    first_slice  = coeff_arrays[[1]][,j[1],]
    second_slice = coeff_arrays[[2]][,j[2],]
    wide_tbl     = dplyr::as_tibble(cbind(first_slice, second_slice), rownames = 'kmer')

    # Everything but the leading `kmer` column is a mutation class.
    tidyr::pivot_longer(wide_tbl, 2:ncol(wide_tbl),
                        names_to = 'mut_class', values_to = 'model_coeff')
  }

  level_names = c('min', 'mid', 'max')
  plots = lapply(level_names, function(lvl) plotter(subset_aux(lvl)))
  names(plots) = level_names
  plots
}
lindberg-m/contextendR documentation built on Jan. 8, 2022, 3:16 a.m.