R/FitGLMModel.R

Defines functions getCV modelTypeToCyclopsModelType fitGLMModel

Documented in fitGLMModel

# @file fitGLMModel.R
#
# Copyright 2020 Observational Health Data Sciences and Informatics
#
# This file is part of PatientLevelPrediction
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#' Fit a predictive model
#'
#' Fits a generalized linear model (logistic, Poisson or Cox) with Cyclops on
#' the covariates of the subjects in \code{population}.
#'
#' @param population   A population object generated by \code{createStudyPopulation}, potentially filtered by other functions.
#'
#' @param plpData       An object of type \code{plpData} as generated using
#'                               \code{getDbPlpData}.
#' @param modelType              The type of outcome model that will be used. Possible values are
#'                               "logistic", "poisson", or "cox".
#' @param excludeCovariateIds    Exclude these covariates from the outcome model.
#' @param includeCovariateIds    Include only these covariates in the outcome model.
#' @param prior                  The prior used to fit the model. See
#'                               \code{\link[Cyclops]{createPrior}} for details.
#' @param control                The control object used to control the cross-validation used to
#'                               determine the hyperparameters of the prior (if applicable). See
#'                               \code{\link[Cyclops]{createControl}} for details.
#'
#' @return An object of class \code{plpModel}: the \code{metaData} attribute of
#'   \code{population} extended with \code{coefficients}, \code{priorVariance},
#'   \code{log_likelihood}, \code{modelType}, \code{modelStatus} and
#'   \code{populationCounts}; additionally \code{timeAtRisk} for Poisson/Cox
#'   models and \code{cv} (per-fold hold-out results) for logistic models fitted
#'   with cross-validation. When the population is empty or contains no
#'   outcomes, no model is fitted, \code{coefficients} is \code{0} and
#'   \code{modelStatus} states the reason.
#'
#' @export
fitGLMModel <- function(population,
                        plpData,
                        modelType = "logistic",
                        excludeCovariateIds = c(),
                        includeCovariateIds = c(),
                        prior = createPrior("laplace", useCrossValidation = TRUE),
                        control = createControl(cvType = "auto",
                                                fold = 3,
                                                startingVariance = 0.01,
                                                tolerance = 2e-06,
                                                cvRepetitions = 1,
                                                selectorType = "byPid",
                                                noiseLevel = "silent",
                                                threads = -1,
                                                maxIterations = 3000)) {
  start <- Sys.time()

  # Defaults for every path on which no model gets fitted, so that assembling
  # the returned plpModel below never touches undefined objects.
  fit <- NULL
  cyclopsData <- NULL
  coefficients <- c(0)

  if (nrow(population) == 0) {
    status <- "NO SUBJECTS IN POPULATION, CANNOT FIT"
    ParallelLogger::logWarn(paste("GLM fitting issue: ", status))
  } else if (sum(population$outcomeCount) == 0) {
    status <- "NO OUTCOMES FOUND FOR POPULATION, CANNOT FIT"
    ParallelLogger::logWarn(paste("GLM fitting issue: ", status))
  } else {
    colnames(population)[colnames(population) == "outcomeCount"] <- "y"

    covariateData <- limitCovariatesToPopulation(plpData$covariateData, population$rowId)

    # Optionally keep only includeCovariateIds, then drop excludeCovariateIds.
    covariates <- covariateData$covariates
    if (length(includeCovariateIds) != 0) {
      covariates <- covariates %>%
        dplyr::filter(covariateId %in% includeCovariateIds)
    }
    if (length(excludeCovariateIds) != 0) {
      covariates <- covariates %>%
        dplyr::filter(!covariateId %in% excludeCovariateIds)
    }

    # Cox and logistic models are fitted on a binary outcome: any non-zero
    # outcome count becomes 1.
    if (modelType %in% c("cox", "logistic")) {
      population$y[population$y != 0] <- 1
    }
    if (modelType == "cox") {
      population$time <- population$survivalTime
    } else if (modelType == "poisson") {
      population$time <- population$timeAtRisk
    }
    # Only Cox models are fitted without an intercept.
    addIntercept <- modelType != "cox"

    covariateData$andromedaPopulation <- population[, !colnames(population) %in% c('cohortStartDate'), drop = FALSE]

    cyclopsData <- Cyclops::convertToCyclopsData(outcomes = covariateData$andromedaPopulation,
                                                 covariates = covariates,
                                                 addIntercept = addIntercept,
                                                 modelType = modelTypeToCyclopsModelType(modelType),
                                                 #checkSorting = TRUE,
                                                 checkRowIds = FALSE,
                                                 normalize = NULL,
                                                 quiet = TRUE)

    # Errors from Cyclops propagate to the caller; 'Done.' is logged either way.
    if (prior$useCrossValidation) {
      fit <- tryCatch({
        ParallelLogger::logInfo('Running Cyclops')
        Cyclops::fitCyclopsModel(cyclopsData, prior = prior, control = control)
      }, finally = ParallelLogger::logInfo('Done.'))
    } else {
      fit <- tryCatch({
        ParallelLogger::logInfo('Running Cyclops with fixed variance')
        Cyclops::fitCyclopsModel(cyclopsData, prior = prior)
      }, finally = ParallelLogger::logInfo('Done.'))
    }

    if (is.character(fit)) {
      # Defensive: a character result is treated as an error message. Reset
      # fit so that the fit$... accesses below do not fail on an atomic vector.
      status <- fit
      fit <- NULL
    } else if (fit$return_flag == "ILLCONDITIONED") {
      status <- "ILL CONDITIONED, CANNOT FIT"
      ParallelLogger::logWarn(paste("GLM fitting issue: ", status))
    } else if (fit$return_flag == "MAX_ITERATIONS") {
      status <- "REACHED MAXIMUM NUMBER OF ITERATIONS, CANNOT FIT"
      ParallelLogger::logWarn(paste("GLM fitting issue: ", status))
    } else {
      status <- "OK"
      # stats::coef is the S3 generic; presumably dispatches to the Cyclops
      # method for the fitted object.
      coefficients <- stats::coef(fit)
      ParallelLogger::logInfo(paste("GLM fit status: ", status))
    }
  }

  outcomeModel <- attr(population, "metaData")
  outcomeModel$coefficients <- coefficients
  # fit is NULL when nothing was fitted, so these two fields stay unset then.
  outcomeModel$priorVariance <- fit$variance
  outcomeModel$log_likelihood <- fit$log_likelihood
  outcomeModel$modelType <- modelType
  outcomeModel$modelStatus <- status
  outcomeModel$populationCounts <- getCounts(population, "Population count")
  if (modelType == "poisson" || modelType == "cox") {
    # NOTE(review): the timeAtRisk column is summed for Cox models too, even
    # though Cox was fitted on survivalTime -- confirm this is intended.
    timeAtRisk <- data.frame(sum(population$timeAtRisk))
    outcomeModel$timeAtRisk <- timeAtRisk
  }
  class(outcomeModel) <- "plpModel"
  delta <- Sys.time() - start
  ParallelLogger::logInfo(paste("Fitting model took", signif(delta, 3), attr(delta, "units")))

  # Per-fold hold-out evaluation; only possible when a model was actually fitted.
  if (!is.null(fit) && modelType == "logistic" && prior$useCrossValidation) {
    outcomeModel$cv <- getCV(cyclopsData, population, cvVariance = fit$variance)
  }

  return(outcomeModel)
}

# Maps a PatientLevelPrediction model type to the Cyclops model type code.
#
# @param modelType   One of "logistic", "poisson" or "cox".
# @param stratified  If TRUE, return the stratified variant for logistic
#                    ("clr") and Poisson ("cpr") models; ignored for "cox".
# @return One of "lr", "clr", "pr", "cpr" or "cox".
# Any other modelType is logged via ParallelLogger and raises an error whose
# message names the offending type.
modelTypeToCyclopsModelType <- function(modelType, stratified = FALSE) {
  # as.character() prevents switch() from treating a numeric modelType as a
  # positional index into the alternatives.
  switch(
    as.character(modelType),
    logistic = if (stratified) "clr" else "lr",
    poisson = if (stratified) "cpr" else "pr",
    cox = "cox",
    {
      errorMessage <- paste("Unknown model type:", modelType)
      ParallelLogger::logError(errorMessage)
      stop(errorMessage, call. = FALSE)
    }
  )
}



# Cross-validated hold-out evaluation with a fixed prior variance.
#
# For every fold k in 1..max(population$indexes), refits the model with the
# rows of fold k given weight 0 and evaluates it on those held-out rows.
#
# @param cyclopsData  Cyclops data object built from `population`.
# @param population   Data frame with columns rowId, indexes (fold number) and
#                     y (outcome).
# @param cvVariance   Variance of the Laplace prior held fixed in every fold fit.
# @return A list with one element per fold, each holding out_sample_auc,
#   predCV (rowId, indexes, y, value, outcomeCount of the held-out rows),
#   log_likelihood, log_prior and coef of that fold's fit. An empty list when
#   population$indexes has no positive value.
getCV <- function(cyclopsData,
                  population,
                  cvVariance) {
  fixedPrior <- Cyclops::createPrior("laplace", variance = cvVariance, useCrossValidation = FALSE)

  # max(0, ...) keeps seq_len() valid when there are no positive fold indexes.
  nFolds <- max(0, population$indexes)
  nRows <- Cyclops::getNumberOfRows(cyclopsData)

  cvResults <- lapply(seq_len(nFolds), function(fold) {
    holdOut <- population$indexes == fold

    # Zero weight removes the held-out fold from this fit.
    # NOTE(review): assumes the rows of cyclopsData are in the same order as
    # the rows of population -- confirm against convertToCyclopsData sorting.
    weights <- rep(1.0, nRows)
    weights[holdOut] <- 0.0
    foldFit <- suppressWarnings(Cyclops::fitCyclopsModel(cyclopsData,
                                                         prior = fixedPrior,
                                                         weights = weights))
    predictions <- stats::predict(foldFit)

    auc <- aucWithoutCi(predictions[holdOut], population$y[holdOut])

    predCV <- cbind(population[holdOut, c('rowId', 'indexes', 'y')],
                    value = predictions[holdOut])
    predCV$outcomeCount <- predCV$y

    list(out_sample_auc = auc,
         predCV = predCV,
         log_likelihood = foldFit$log_likelihood,
         log_prior = foldFit$log_prior,
         coef = stats::coef(foldFit))
  })

  return(cvResults)
}
hxia/plp-git-demo documentation built on March 19, 2021, 1:54 a.m.