R/Recalibration.R

Defines functions inverseLog logFunct weakRecalibration recalibrationInTheLarge recalibratePlp recalibratePlpRefit

Documented in recalibratePlp recalibratePlpRefit

# @file Recalibration.R
#
# Copyright 2021 Observational Health Data Sciences and Informatics
#
# This file is part of PatientLevelPrediction
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


#' recalibratePlpRefit
#'
#' @description
#' Train various models using a default parameter gird search or user specified parameters
#'
#' @details
#' The user can define the machine learning model to train (regularised logistic regression, random forest,
#' gradient boosting machine, neural network and )
#' 
#' @param plpModel                         The trained plpModel (runPlp$model)
#' @param newPopulation                    The population created using createStudyPopulation() who will have their risks predicted
#' @param newData                          An object of type \code{plpData} - the patient level prediction
#'                                         data extracted from the CDM.
#' @return
#' An object of class \code{runPlp} that is recalibrated on the new data
#'
#' @export
recalibratePlpRefit <- function(
  plpModel,
  newPopulation, 
  newData
){
  if (is.null(newPopulation)){
    stop("NULL population")
  }
  if (!inherits(x = newData, what =  "plpData")){
    stop("Incorrect plpData class")
  }
  if (!inherits(x = plpModel, what = "plpModel")){
    stop("plpModel is not of class plpModel")
  }
  
  #get selected covariates
  includeCovariateIds <- plpModel$covariateImportance %>% 
    dplyr::filter(.data$covariateValue != 0) %>% 
    dplyr::select("covariateId") %>% 
    dplyr::pull()
  
  # check which covariates are included in new data
  containedIds <- newData$covariateData$covariateRef %>% dplyr::collect()
  noShrinkage <- intersect(includeCovariateIds, containedIds$covariateId)
  
  # add intercept
  noShrinkage <- append(noShrinkage, 0, 0)
  
  setLassoRefit <- setLassoLogisticRegression(
    includeCovariateIds = includeCovariateIds,
    noShrinkage = noShrinkage, 
    maxIterations = 10000 # increasing this due to test code often not converging
  )
  
  newData$labels <- newPopulation 
  
  newData$folds <- data.frame(
    rowId = newData$labels$rowId, 
    index = sample(2, length(newData$labels$rowId), replace = T)
    )
  
  # add dummy settings to fit model
  attr(newData, "metaData")$outcomeId <- attr(newPopulation, 'metaData')$outcomeId
  attr(newData, "metaData")$targetId <- attr(newPopulation, 'metaData')$targetId
  attr(newData, "metaData")$restrictPlpDataSettings <- attr(newPopulation, 'metaData')$restrictPlpDataSettings
  attr(newData, "metaData")$covariateSettings <- newData$metaData$covariateSettings
  attr(newData, "metaData")$populationSettings <- attr(newPopulation, 'metaData')$populationSettings
  attr(newData$covariateData, "metaData")$featureEngineeringSettings <- PatientLevelPrediction::createFeatureEngineeringSettings()
  attr(newData$covariateData, "metaData")$preprocessSettings <- PatientLevelPrediction::createPreprocessSettings()
  attr(newData, "metaData")$splitSettings <- PatientLevelPrediction::createDefaultSplitSetting()
  attr(newData, "metaData")$sampleSettings <- PatientLevelPrediction::createSampleSettings()
  
  newModel <- tryCatch({
    fitPlp(
      trainData = newData, 
      modelSettings = setLassoRefit,
      analysisId = 'recalibrationRefit',
      analysisPath = NULL
    )
  }, 
  error = function(e){ParallelLogger::logInfo(e); return(NULL)}
  )
  if(is.null(newModel)){
    ParallelLogger::logInfo('Recalibration fit failed')
    return(NULL)
  }
  
  newModel$prediction$evaluationType <- 'recalibrationRefit'

  oldPred <- predictPlp(
    plpModel = plpModel, 
    plpData = newData, 
    population = newPopulation, 
    timepoint = 0
    )
  
  oldPred$evaluationType <- 'validation'
  
  prediction <- rbind(
    oldPred, 
    newModel$prediction[, colnames(oldPred)]
    )

  if(!is.null(newModel$covariateImportance)){
    adjust <- newModel$covariateImportance %>% 
      dplyr::filter(.data$covariateValue != 0) %>% 
      dplyr::select(
        "covariateId", 
        "covariateValue"
      )
  } else{
    adjust <- c()
  }
  
  newIntercept <- newModel$model$coefficients[names(newModel$model$coefficients) == '(Intercept)']
  
  attr(prediction, "metaData")$recalibratePlpRefit <- list(adjust = adjust, newIntercept = newIntercept)
  
  return(prediction)
}


#' recalibratePlp
#'
#' @description
#' Train various models using a default parameter gird search or user specified parameters
#'
#' @details
#' The user can define the machine learning model to train (regularised logistic regression, random forest,
#' gradient boosting machine, neural network and )
#' 
#' @param prediction                      A prediction dataframe
#' @param analysisId                      The model analysisId
#' @param typeColumn                      The column name where the strata types are specified
#' @param method                          Method used to recalibrate ('recalibrationInTheLarge' or 'weakRecalibration' )
#' @return
#' An object of class \code{runPlp} that is recalibrated on the new data
#'

#' @export
recalibratePlp <- function(prediction, analysisId, typeColumn = 'evaluationType',
                           method = c('recalibrationInTheLarge', 'weakRecalibration')){
  # check input:
    if (!inherits(x = prediction, what =  'data.frame')){
      stop("Incorrect prediction") 
    }
  
  if(!method  %in% c('recalibrationInTheLarge', 'weakRecalibration')){
    stop("Unknown recalibration method type. must be of type: recalibrationInTheLarge, weakRecalibration")
  }
  
  prediction <- do.call(method, list(prediction = prediction, columnType = typeColumn))
  
  return(prediction)
  
}



recalibrationInTheLarge <- function(prediction, columnType = 'evaluationType'){
  
  if(attr(prediction, "metaData")$modelType == 'binary'){
    misCal <- calibrationInLarge(prediction)
    obsOdds <- misCal$observedRisk/ (1-misCal$observedRisk)
    predOdds <- misCal$meanPredictionRisk/ (1 -  misCal$meanPredictionRisk)
    correctionFactor <- log(obsOdds / predOdds)
    
    recalibrated <- prediction
    recalibrated$value = logFunct(inverseLog(recalibrated$value) + correctionFactor)
    
    recalibrated[,columnType] <- 'recalibrationInTheLarge'
    prediction <- rbind(prediction, recalibrated)
    attr(prediction, 'metaData')$recalibrationInTheLarge = list(correctionFactor = correctionFactor)
    
    return(prediction)
  }
  
  if(attr(prediction, "metaData")$modelType == 'survival'){

    ParallelLogger::logError('Survival recal in the large not currently available')
  }
  
  
}


weakRecalibration <- function(prediction, columnType = 'evaluationType'){
  
  # if binary:
  if(attr(prediction, "metaData")$modelType == 'binary'){
    recalibrated <- prediction
    recalibrated$value[recalibrated$value==0] <- 0.000000000000001
    recalibrated$value[recalibrated$value==1] <- 1-0.000000000000001
    
    y <- ifelse(recalibrated$outcomeCount>0, 1, 0)
    inverseLog <- inverseLog(recalibrated$value)
    refit <- suppressWarnings(stats::glm(y ~ inverseLog, family = 'binomial'))
    
    recalibrated$value <- logFunct((inverseLog * refit$coefficients[2]) + refit$coefficients[1])
    
    recalibrated[,columnType] <- 'weakRecalibration'
    prediction <- rbind(prediction, recalibrated)
    attr(prediction, 'metaData')$weakRecalibration = list(
      adjustGradient = refit$coefficients[2], 
      adjustIntercept = refit$coefficients[1]
      )
    
    return(prediction)
  } 
  
  # add if survival
  if(attr(prediction, "metaData")$modelType == 'survival'){
    
    recalibrated <- prediction
    
    # this will make the recalibration work if the baselineSurvival is missing
    baseline <- ifelse(is.null(attr(recalibrated, "baselineSurvival")), 0.9, attr(recalibrated, "baselineSurvival"))
    ParallelLogger::logInfo(paste0('recal initial baseline hazard: ',baseline))
    
    offset <- ifelse(is.null(attr(recalibrated, "offset")), 0, attr(recalibrated, "offset"))
    ParallelLogger::logInfo(paste0('recal initial offset: ',offset))
    
    timepoint <- ifelse(is.null(attr(recalibrated, "timePoint")), 365, attr(recalibrated, "timePoint"))
    ParallelLogger::logInfo(paste0('recal initial timepoint: ',timepoint))
    
    if(!is.null(baseline)){
      lp <- log(log(1-recalibrated$value)/log(baseline)) + offset
    } else{
      lp <- log(recalibrated$value)
    }
    
    
    t <- apply(cbind(recalibrated$daysToCohortEnd, recalibrated$survivalTime), 1, min)
    y <- ifelse(recalibrated$outcomeCount>0,1,0)  # observed outcome
    y[t>timepoint] <- 0
    t[t>timepoint] <- timepoint
    S<- survival::Surv(t, y) 
    #### Intercept + Slope recalibration
    f.slope <- survival::coxph(S~lp)
    h.slope <- max(survival::basehaz(f.slope)$hazard)  # maximum OK because of prediction_horizon
    lp.slope <- stats::predict(f.slope)
    recalibrated$value <- 1-exp(-h.slope*exp(lp.slope))
    # 1-h.slope^exp(lp.slope)
    
    
    recalibrated[,columnType] <- 'weakRecalibration'
    prediction <- rbind(prediction, recalibrated)
    attr(prediction, 'metaData')$weakRecalibration = list(
      adjustGradient = f.slope$coefficients['lp'], 
      adjustIntercept = h.slope
    )
    
    return(prediction)
    
  } 
  
} 

logFunct <- function(values){
  return(1/(1 + exp(0 - values)))
}

inverseLog <- function(values){
  res <- log(values/(1-values))
  return(res)
}
OHDSI/PatientLevelPrediction documentation built on April 27, 2024, 8:11 p.m.