R/calculate_PC1.R

#' Compute PC1 scores for a chemical language model
#'
#' The main function of the \code{CLMeval} package, \code{calculate_PC1}
#' evaluates a chemical language model by integrating five orthogonal metrics
#' of model performance. It does so through principal component analysis (PCA)
#' of a dataset in which the major dimension of variance is model performance;
#' that is, models segregate along the first principal component according to
#' how closely they match the chemical space of their training set. The
#' function performs PCA on a reference matrix of chemical outcomes, then uses
#' the \code{\link{predict}} function from \pkg{stats} to project the model of
#' interest onto the same principal components.
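#'
#' The projection step can be sketched in a few lines. This is a minimal,
#' illustrative example, not the package implementation: the matrix
#' \code{toy_reference} is a hypothetical stand-in for \code{reference}, filled
#' with random values, and its column names are assumed rather than taken from
#' the real data. It checks that the PC1 score returned by \code{predict()}
#' equals the query standardised with the centre and scale stored in the
#' \code{princomp} fit, then weighted by the PC1 loadings.
#'
#' ```r
#' set.seed(42)
#' # hypothetical toy reference: 40 "models" by 5 metrics of random values
#' metric_names = c("pct_valid", "FCD", "JSD_stereocenters",
#'                  "JSD_murcko", "JSD_NP")
#' toy_reference = matrix(rnorm(200), nrow = 40,
#'                        dimnames = list(NULL, metric_names))
#' # PCA on the correlation matrix, as in calculate_PC1
#' pca = princomp(toy_reference, cor = TRUE)
#'
#' # a hypothetical query model: a 1-row matrix with matching column names
#' query = toy_reference[1, , drop = FALSE] + 0.1
#'
#' # PC1 score via predict() ...
#' pc_predict = predict(pca, newdata = query)[1, "Comp.1"]
#' # ... and by hand: standardise with the fitted centre and scale,
#' # then take the weighted sum with the PC1 loadings
#' z = (query[1, ] - pca$center) / pca$scale
#' pc_manual = sum(z * pca$loadings[, "Comp.1"])
#' all.equal(pc_predict, pc_manual)  # TRUE
#' ```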
#' 
#' The function takes as input five metrics that reflect the quality of a
#' chemical language model. These metrics were chosen because, across a series
#' of benchmarking analyses, they were robustly correlated with the number of
#' molecules in the training set. The five metrics are:
#' 
#' \enumerate{
#'   \item the proportion of valid molecules generated by the trained model
#'   \item the Frechet ChemNet distance to the training set 
#'   \item the Jensen-Shannon distance between the number of
#'     stereocenters in molecules sampled from the trained model vs. the
#'     training set
#'   \item the Jensen-Shannon distance between the frequency 
#'     distribution of Murcko scaffolds within molecules sampled from the trained 
#'     model vs. the training set
#'   \item the Jensen-Shannon distance between the natural
#'     product-likeness scores of molecules sampled from the trained model
#'     vs. the training set
#' }
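#'
#' As an illustration of the three Jensen-Shannon metrics, the sketch below
#' computes a Jensen-Shannon distance between two distributions of counts over
#' the same categories. It assumes the distance is the square root of the
#' Jensen-Shannon divergence with base-2 logarithms, so it lies between 0 and
#' 1; the original analysis may use another convention. The helper
#' \code{jsd_counts()} and the counts are hypothetical and not part of
#' \code{CLMeval}.
#'
#' ```r
#' # Hypothetical helper: Jensen-Shannon distance between two count vectors
#' # aligned on the same categories (e.g. molecules with 0, 1, 2, ...
#' # stereocenters). Assumes sqrt(JS divergence) with log base 2.
#' jsd_counts = function(x, y) {
#'   p = x / sum(x)
#'   q = y / sum(y)
#'   m = (p + q) / 2
#'   # Kullback-Leibler divergence; zero-probability terms contribute 0
#'   kl = function(a, b) {
#'     keep = a > 0
#'     sum(a[keep] * log2(a[keep] / b[keep]))
#'   }
#'   sqrt((kl(p, m) + kl(q, m)) / 2)
#' }
#'
#' # hypothetical stereocenter counts (0, 1, 2, 3 stereocenters)
#' training = c(500, 300, 150, 50)
#' sampled = c(550, 280, 130, 40)
#' jsd_counts(training, sampled)   # small, since the shapes are similar
#' jsd_counts(training, training)  # 0 for identical distributions
#' ```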
#' 
#' The reference matrix used to perform PCA contains metrics for a total of 440
#' chemical language models. These were obtained by training recurrent neural
#' network (RNN) models on SMILES strings from the ChEMBL, COCONUT, GDB, and
#' ZINC databases. The number of molecules in each training set varied between
#' 1,000 and 500,000 in eleven increments, and ten random samples of each size
#' were drawn from each database. For further details, see the
#' \code{\link{reference}} documentation.
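#'
#' The count of 440 follows directly from this design: four databases, eleven
#' training-set sizes, and ten replicate samples per size. The grid below
#' enumerates it as a check. The sizes are indexed 1 to 11 because only the
#' smallest (1,000) and largest (500,000) are stated here; the grid is an
#' illustration, not the layout of the \code{reference} object.
#'
#' ```r
#' design = expand.grid(
#'   database = c("ChEMBL", "COCONUT", "GDB", "ZINC"),
#'   size_step = 1:11,  # 1 = 1,000 molecules, ..., 11 = 500,000 molecules
#'   replicate = 1:10
#' )
#' nrow(design)  # 440, i.e. 4 * 11 * 10
#' ```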
#' 
#' For further details on the metrics, see the complete description of the
#' analysis at [doi:10.26434/chemrxiv.13638347.v1](https://doi.org/10.26434/chemrxiv.13638347.v1).
#'
#' @param pct_valid the proportion of valid molecules generated by the trained
#'   model, between 0 and 1
#' @param FCD the Frechet ChemNet distance to the training set 
#' @param JSD_stereocenters the Jensen-Shannon distance between the number of
#'   stereocenters in molecules sampled from the trained model vs. the
#'   training set
#' @param JSD_murcko the Jensen-Shannon distance between the frequency 
#'   distribution of Murcko scaffolds within molecules sampled from the trained 
#'   model vs. the training set
#' @param JSD_NP the Jensen-Shannon distance between the natural
#'   product-likeness scores of molecules sampled from the trained model
#'   vs. the training set
#'
#' @return a scalar value representing the model's PC1 score, derived from the
#'   integration of all five metrics
#'
#' @importFrom stats princomp predict
#' @importFrom utils data
#' @importFrom magrittr set_colnames %>%
#'
#' @export
calculate_PC1 = function(pct_valid,
                         FCD,
                         JSD_stereocenters,
                         JSD_murcko,
                         JSD_NP) {
  # check inputs: each metric must be a single, non-missing number
  metrics = list(pct_valid = pct_valid, FCD = FCD,
                 JSD_stereocenters = JSD_stereocenters,
                 JSD_murcko = JSD_murcko, JSD_NP = JSD_NP)
  for (name in names(metrics)) {
    value = metrics[[name]]
    if (!is.numeric(value)) stop(name, " is not numeric")
    if (length(value) != 1 || is.na(value))
      stop(name, " must be a single, non-missing value")
  }
  if (pct_valid < 0 || pct_valid > 1) stop("pct_valid not in range [0, 1]")
  if (FCD < 0) stop("FCD cannot be negative")
  if (JSD_stereocenters < 0) stop("JSD_stereocenters cannot be negative")
  if (JSD_murcko < 0) stop("JSD_murcko cannot be negative")
  if (JSD_NP < 0) stop("JSD_NP cannot be negative")
  
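  # load the reference matrix from the package data; naming the package lets
  # this work even when CLMeval is loaded but not attached
  data(reference, package = "CLMeval", envir = environment())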
  
  # do PCA in the reference data
  pca = princomp(reference, cor = TRUE)
  
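  # set up query; values must follow the column order of `reference`,
  # since set_colnames() assigns the names by position
  query = matrix(c(pct_valid, FCD, JSD_stereocenters, JSD_murcko, JSD_NP,
                   ## add a second, all-NA row to avoid dropping to a vector
                   rep(NA, 5)),
                 nrow = 2, byrow = TRUE) %>% 
    set_colnames(colnames(reference))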
  
  # predict
  pred = predict(pca, newdata = query)
  
  # pull out PC1
  return(pred[1, 'Comp.1'])
}
skinnider/CLMeval documentation built on Dec. 23, 2021, 3:23 a.m.