R/trainTree.R

#' trainTree Function
#'
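#' Trains a decision tree model from patient data, which can be either continuous gene
#' expression levels or a binary matrix indicating mutations. The input is treated as
#' binary when it contains exactly two distinct values; otherwise it is treated as
#' expression data.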
#'
#' @param PatientData A matrix representing patient features, where rows correspond to patients/samples
#'                    and columns correspond to genes/features. This matrix can contain:
#'                    \itemize{
#'                      \item Binary mutation data (e.g., presence/absence of mutations).
#'                      \item Continuous data from gene expression profiles (e.g., expression levels).
#'                    }
#' @param PatientSensitivity A matrix of drug response values, where rows correspond to patients
#'                           in the same order as in `PatientData`, and columns correspond to drugs.
#'                           Higher values indicate greater drug resistance and, consequently,
#'                           lower sensitivity to treatment. The matrix can hold various measures
#'                           of drug response, such as IC50 values or the area under the drug
#'                           response curve (AUC). If higher values in your data mean greater
#'                           sensitivity instead, adjust the sign of the data before training.
#' @param minbucket An integer specifying the minimum number of patients required in a node to allow for a split.
#'                  Defaults to 20.
#'
#' @return An object of class 'party' representing the trained decision tree, with the assigned treatments for each node.
#'
#' @examples
#'
#' \donttest{
#'   # Basic example of using the trainTree function with mutational data
#'   data("drug_response_w12")
#'   data("mutations_w12")
#'   ODTmut <- trainTree(PatientData = mutations_w12, 
#'                       PatientSensitivity = drug_response_w12,
#'                       minbucket = 10)
#'   plot(ODTmut)
#'
#'   # Example using gene expression data instead
#'   data("drug_response_w34")
#'   data("expression_w34")
#'   ODTExp <- trainTree(PatientData = expression_w34,
#'                       PatientSensitivity = drug_response_w34,
#'                       minbucket = 20)
#'   plot(ODTExp)
#' }
#'
#' @import matrixStats
#' @import partykit
#' @export
trainTree <- function(PatientData,
                      PatientSensitivity,
                      minbucket = 20) {
  # Check if the PatientData matrix is binary (i.e., contains only two unique values)
  if (length(unique(c(unlist(PatientData)))) == 2) {
    # For binary data (mutational data):
    
    # Shift the values so that the smaller of the two becomes 1 (e.g., 0/1 becomes 1/2)
    PatientData <- PatientData - min(PatientData) + 1L
    # Store PatientData as integers for compatibility with the tree-growing function
    mode(PatientData) <- "integer"
    
    # Grow the decision tree using the mutational data
    nodes <- growtreeMut(id = 1L,
                         PatientSensitivity,
                         PatientData,
                         minbucket = minbucket)
  } else {
    # For non-binary data (gene expression data):
    
    # Grow the decision tree using the gene expression data
    nodes <- growtreeExp(id = 1L,
                         PatientSensitivity,
                         PatientData,
                         minbucket = minbucket)
  }
  
  # Create a party object from the nodes generated by the tree-growing functions
  tree <- party(nodes, data = as.data.frame(PatientData))
  
  # Remove additional node information from the tree object
  tree$node$info <- NULL
  
  # Return the trained decision tree
  return(tree)
}
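
The branch at the top of `trainTree` decides between the mutation and expression paths purely from the number of distinct values in `PatientData`. The sketch below reproduces that check and the recoding step on toy matrices. The helper name `is_binary` and all data values are made up for illustration and are not part of the package.

```r
# Hypothetical helper mirroring the check in trainTree():
# exactly two distinct values means the input is treated as binary.
is_binary <- function(x) length(unique(c(unlist(x)))) == 2

# Toy data (illustrative values only): 3 patients x 2 genes
mut <- matrix(c(0, 1, 1, 0, 0, 1), nrow = 3,
              dimnames = list(paste0("P", 1:3), c("GeneA", "GeneB")))
expr <- matrix(c(2.1, 5.3, 0.7, 3.3, 4.8, 1.2), nrow = 3,
               dimnames = dimnames(mut))

is_binary(mut)   # TRUE: only 0 and 1 occur
is_binary(expr)  # FALSE: six distinct values

# A missing value counts as a third distinct value, so this is NOT binary
with_na <- matrix(c(0, 1, NA, 1), nrow = 2)
is_binary(with_na)  # FALSE

# Recoding applied on the binary path: shift to start at 1, store as integer
recoded <- mut - min(mut) + 1L
mode(recoded) <- "integer"
recoded  # same shape and dimnames as mut, with 0/1 recoded to 1/2
```

One consequence of this design is that a mutation matrix containing `NA`, or one in which every entry is identical, falls through to the expression branch, so it may be worth checking the input before training.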

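The documentation for `PatientSensitivity` notes that the sign of the data may need adjusting so that higher values mean greater resistance. A minimal sketch of one way to do this follows, assuming a response measure where higher values mean greater sensitivity. The matrix `sens_score` and its values are hypothetical.

```r
# Hypothetical response scores where HIGHER = MORE sensitive
# (illustrative values only): 3 patients x 2 drugs
sens_score <- matrix(c(0.9, 0.2, 0.5, 0.1, 0.8, 0.4), nrow = 3,
                     dimnames = list(paste0("P", 1:3), c("DrugA", "DrugB")))

# Flip the sign so that higher = more resistant, the convention
# described for the PatientSensitivity argument
PatientSensitivity <- -sens_score

# Sanity check: the preferred drug per patient is unchanged by the flip
best_before <- colnames(sens_score)[apply(sens_score, 1, which.max)]
best_after  <- colnames(PatientSensitivity)[apply(PatientSensitivity, 1, which.min)]
identical(best_before, best_after)
```

Negation preserves the ranking of drugs within each patient while reversing its direction, which is all the sign convention requires. Whether a different transformation suits a given response measure better is left to the user.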
ODT documentation built on Oct. 18, 2024, 5:12 p.m.