R/SaveLoadPlp.R

Defines functions loadPlpResult savePlpResult loadPrediction savePrediction updateModelLocation loadPlpModel moveHdModel savePlpModel print.summary.plpData summary.plpData print.plpData loadPlpData savePlpData getPlpData

Documented in getPlpData loadPlpData loadPlpModel loadPlpResult loadPrediction savePlpData savePlpModel savePlpResult savePrediction

# @file SaveLoadPlp.R
#
# Copyright 2020 Observational Health Data Sciences and Informatics
#
# This file is part of PatientLevelPrediction
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#' Get the patient level prediction data from the server
#' @description
#' This function executes a large set of SQL statements against the database in OMOP CDM format to
#' extract the data needed to perform the analysis.
#'
#' @details
#' Based on the arguments, the at-risk cohort data is retrieved, as well as outcomes
#' occurring in these subjects. The at-risk cohort is identified through
#' user-defined cohorts in a cohort table either inside the CDM instance or in a separate schema.
#' Similarly, outcomes are identified through user-defined cohorts in a cohort table either
#' inside the CDM instance or in a separate schema. Covariates are automatically extracted from
#' the appropriate tables within the CDM. If you wish to exclude concepts from covariates you will
#' need to manually add the concept_ids and descendants to the \code{excludedCovariateConceptIds}
#' of the \code{covariateSettings} argument.
#'
#' @param connectionDetails            An R object of type\cr\code{connectionDetails} created using the
#'                                     function \code{createConnectionDetails} in the
#'                                     \code{DatabaseConnector} package.
#' @param cdmDatabaseSchema            The name of the database schema that contains the OMOP CDM
#'                                     instance.  Requires read permissions to this database. On SQL
#'                                     Server, this should specify both the database and the schema,
#'                                     so for example 'cdm_instance.dbo'.
#' @param oracleTempSchema             For Oracle only: the name of the database schema where you want
#'                                     all temporary tables to be managed. Requires create/insert
#'                                     permissions to this database.
#' @param cohortId                     A unique identifier to define the at-risk cohort. cohortId is
#'                                     used to select the cohort_definition_id (cohort_concept_id in
#'                                     CDM v4) in the cohort-like table.
#' @param outcomeIds                   A vector of cohort_definition_ids used to define outcomes
#'                                     (-999 means no outcomes are downloaded).
#' @param studyStartDate               A calendar date specifying the minimum date that a cohort index
#'                                     date can appear. Date format is 'yyyymmdd'.
#' @param studyEndDate                 A calendar date specifying the maximum date that a cohort index
#'                                     date can appear. Date format is 'yyyymmdd'. Important: the study
#'                                     end date is also used to truncate risk windows, meaning no outcomes
#'                                     beyond the study end date will be considered.
#' @param cohortDatabaseSchema         The name of the database schema that is the location where the
#'                                     cohort data used to define the at risk cohort is available.
#'                                     Requires read permissions to this database.
#' @param cohortTable                  The table name that contains the at-risk cohort. cohortTable has
#'                                     the format of the COHORT table: COHORT_DEFINITION_ID (or
#'                                     COHORT_CONCEPT_ID in CDM v4), SUBJECT_ID, COHORT_START_DATE,
#'                                     COHORT_END_DATE.
#' @param outcomeDatabaseSchema        The name of the database schema that is the location where
#'                                     the data used to define the outcome cohorts is available.
#'                                     Requires read permissions to this database.
#' @param outcomeTable                 The table name that contains the outcome cohorts. outcomeTable
#'                                     is expected to have the format of the COHORT table:
#'                                     COHORT_DEFINITION_ID, SUBJECT_ID, COHORT_START_DATE,
#'                                     COHORT_END_DATE.
#' @param cdmVersion                   Define the OMOP CDM version used: currently supports "4", "5" and "6".
#' @param firstExposureOnly            Should only the first exposure per subject be included? Note that
#'                                     this is typically done in the \code{createStudyPopulation} function,
#'                                     but can already be done here for efficiency reasons.
#' @param washoutPeriod                The minimum required continuous observation time prior to index
#'                                     date for a person to be included in the at risk cohort. Note that
#'                                     this is typically done in the \code{createStudyPopulation} function,
#'                                     but can already be done here for efficiency reasons.
#' @param sampleSize                   If not NULL, only this number of people will be sampled from the target population (Default NULL)
#' 
#' @param covariateSettings            An object of type \code{covariateSettings} as created using the
#'                                     \code{createCovariateSettings} function in the
#'                                     \code{FeatureExtraction} package.
#' @param excludeDrugsFromCovariates   A redundant option that is not used by this function; exclude
#'                                     concepts through \code{covariateSettings} instead.
#'
#' @return
#' Returns an object of type \code{plpData}, containing information on the cohorts, their
#' outcomes, and baseline covariates. Information about multiple outcomes can be captured at once for
#' efficiency reasons. This object is a list with the following components: \describe{
#' \item{cohorts}{A data frame listing the persons in the at-risk cohort, including the time to
#' the end of the observation period and the time to the end of the cohort.}
#' \item{outcomes}{A data frame listing the outcomes per person, including the time to event and
#' the outcome id. Outcomes are not yet filtered based on risk window, since this is done at
#' a later stage. \code{NULL} when \code{outcomeIds} is -999.}
#' \item{covariateData}{The covariate data object returned by
#' \code{FeatureExtraction::getDbCovariateData}, holding the baseline covariates per person in a
#' sparse representation (covariates with a value of 0 are omitted to save space) together with
#' the covariateRef table describing the covariates that have been extracted.}
#' \item{timeRef}{For temporal covariate settings, a data frame mapping each timeId to its
#' startDay and endDay; otherwise \code{NULL}.}
#' \item{metaData}{A list of objects with information on how the plpData object was
#' constructed.} } The generic \code{print()} and \code{summary()} functions have been
#' implemented for this object.
#'
#' @export
getPlpData <- function(connectionDetails,
                       cdmDatabaseSchema,
                       oracleTempSchema = cdmDatabaseSchema,
                       cohortId,
                       outcomeIds,
                       studyStartDate = "",
                       studyEndDate = "",
                       cohortDatabaseSchema = cdmDatabaseSchema,
                       cohortTable = "cohort",
                       outcomeDatabaseSchema = cdmDatabaseSchema,
                       outcomeTable = "cohort",
                       cdmVersion = "5",
                       firstExposureOnly = FALSE,
                       washoutPeriod = 0,
                       sampleSize = NULL,
                       covariateSettings,
                       excludeDrugsFromCovariates = FALSE) {
  if (studyStartDate != "" && regexpr("^[12][0-9]{3}[01][0-9][0-3][0-9]$", studyStartDate) == -1)
    stop("Study start date must have format YYYYMMDD")
  if (studyEndDate != "" && regexpr("^[12][0-9]{3}[01][0-9][0-3][0-9]$", studyEndDate) == -1)
    stop("Study end date must have format YYYYMMDD")
  if(!is.null(sampleSize)){
    if(!class(sampleSize) %in% c('numeric', 'integer'))
      stop("sampleSize must be numeric")
  }
  
  if(is.null(cohortId))
    stop('User must input cohortId')
  if(length(cohortId)>1)
    stop('Currently only supports one cohortId at a time')
  if(is.null(outcomeIds))
    stop('User must input outcomeIds')
  #ToDo: add other checks the inputs are valid
  
  connection <- DatabaseConnector::connect(connectionDetails)
  on.exit(DatabaseConnector::disconnect(connection))
  dbms <- connectionDetails$dbms
  
  writeLines("\nConstructing the at risk cohort")
  if(!is.null(sampleSize))  writeLines(paste("\nSampling", sampleSize, "people"))
  renderedSql <- SqlRender::loadRenderTranslateSql("CreateCohorts.sql",
                                                   packageName = "PatientLevelPrediction",
                                                   dbms = dbms,
                                                   oracleTempSchema = oracleTempSchema,
                                                   cdm_database_schema = cdmDatabaseSchema,
                                                   cohort_database_schema = cohortDatabaseSchema,
                                                   cohort_table = cohortTable,
                                                   cdm_version = cdmVersion,
                                                   cohort_id = cohortId,
                                                   study_start_date = studyStartDate,
                                                   study_end_date = studyEndDate,
                                                   first_only = firstExposureOnly,
                                                   washout_period = washoutPeriod,
                                                   use_sample = !is.null(sampleSize),
                                                   sample_number=sampleSize
  )
  DatabaseConnector::executeSql(connection, renderedSql)
  
  writeLines("Fetching cohorts from server")
  start <- Sys.time()
  cohortSql <- SqlRender::loadRenderTranslateSql("GetCohorts.sql",
                                                 packageName = "PatientLevelPrediction",
                                                 dbms = dbms,
                                                 oracleTempSchema = oracleTempSchema,
                                                 cdm_version = cdmVersion)
  cohorts <- DatabaseConnector::querySql(connection, cohortSql)
  colnames(cohorts) <- SqlRender::snakeCaseToCamelCase(colnames(cohorts))
  metaData.cohort <- list(cohortId = cohortId,
                          studyStartDate = studyStartDate,
                          studyEndDate = studyEndDate)
  
  if(nrow(cohorts)==0)
    stop('Target population is empty')
  
  delta <- Sys.time() - start
  writeLines(paste("Loading cohorts took", signif(delta, 3), attr(delta, "units")))
  
  #covariateSettings$useCovariateCohortIdIs1 <- TRUE
  covariateData <- FeatureExtraction::getDbCovariateData(connection = connection,
                                                         oracleTempSchema = oracleTempSchema,
                                                         cdmDatabaseSchema = cdmDatabaseSchema,
                                                         cdmVersion = cdmVersion,
                                                         cohortTable = "#cohort_person",
                                                         cohortTableIsTemp = TRUE,
                                                         rowIdField = "row_id",
                                                         covariateSettings = covariateSettings)
  # add indexes for covariate summary
  RSQLite::dbExecute(covariateData, "CREATE INDEX covsum_rowId ON covariates(rowId)")
  RSQLite::dbExecute(covariateData, "CREATE INDEX covsum_covariateId ON covariates(covariateId)")
  
  
  if(max(outcomeIds)!=-999){
    writeLines("Fetching outcomes from server")
    start <- Sys.time()
    outcomeSql <- SqlRender::loadRenderTranslateSql("GetOutcomes.sql",
                                                    packageName = "PatientLevelPrediction",
                                                    dbms = dbms,
                                                    oracleTempSchema = oracleTempSchema,
                                                    cdm_database_schema = cdmDatabaseSchema,
                                                    outcome_database_schema = outcomeDatabaseSchema,
                                                    outcome_table = outcomeTable,
                                                    outcome_ids = outcomeIds,
                                                    cdm_version = cdmVersion)
    outcomes <- DatabaseConnector::querySql(connection, outcomeSql)
    colnames(outcomes) <- SqlRender::snakeCaseToCamelCase(colnames(outcomes))
    metaData.outcome <- data.frame(outcomeIds =outcomeIds)
    attr(outcomes, "metaData") <- metaData.outcome
    if(nrow(outcomes)==0){
      metaData.cohort$attrition <- getCounts3(cohorts,outcomes, outcomeIds, "Original cohorts")  
    }else{
      metaData.cohort$attrition <- getCounts2(cohorts,outcomes, "Original cohorts")  
    }
    
    delta <- Sys.time() - start
    writeLines(paste("Loading outcomes took", signif(delta, 3), attr(delta, "units")))
  } else {
    outcomes <- NULL
  }
  # set the cohort metaData in both branches so that print() and summary() can report the cohortId
  attr(cohorts, "metaData") <- metaData.cohort
  
  # Remove temp tables:
  renderedSql <- SqlRender::loadRenderTranslateSql("RemoveCohortTempTables.sql",
                                                   packageName = "PatientLevelPrediction",
                                                   dbms = dbms,
                                                   oracleTempSchema = oracleTempSchema)
  DatabaseConnector::executeSql(connection, renderedSql, progressBar = FALSE, reportOverallTime = FALSE)
  #DatabaseConnector::disconnect(connection)
  
  metaData <- covariateData$metaData
  metaData$call <- match.call()
  metaData$call$connectionDetails = connectionDetails
  metaData$call$connection = NULL
  metaData$call$cdmDatabaseSchema = cdmDatabaseSchema
  metaData$call$oracleTempSchema = oracleTempSchema
  metaData$call$cohortId = cohortId
  metaData$call$outcomeIds = outcomeIds
  metaData$call$studyStartDate = studyStartDate
  metaData$call$studyEndDate = studyEndDate
  metaData$call$cohortDatabaseSchema = cohortDatabaseSchema
  metaData$call$cohortTable = cohortTable
  metaData$call$outcomeDatabaseSchema = outcomeDatabaseSchema
  metaData$call$outcomeTable = outcomeTable
  metaData$call$cdmVersion = cdmVersion
  metaData$call$firstExposureOnly = firstExposureOnly
  metaData$call$washoutPeriod = washoutPeriod
  metaData$call$covariateSettings= covariateSettings
  metaData$call$sampleSize = sampleSize
  
  # create the temporal settings (if temporal covariates are used)
  timeReference <- NULL
  if (isTRUE(covariateSettings$temporal) &&
      length(covariateSettings$temporalStartDays) > 0) {
    # make sure time days populated
    timeReference <- data.frame(timeId = seq_along(covariateSettings$temporalStartDays),
                                startDay = covariateSettings$temporalStartDays,
                                endDay = covariateSettings$temporalEndDays)
  }
  
  
  result <- list(cohorts = cohorts,
                 outcomes = outcomes,
                 covariateData = covariateData,
                 timeRef = timeReference,
                 metaData = metaData)
  
  class(result) <- "plpData"
  return(result)
}
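
# Sketch (hypothetical helper, not part of the original file and not called by
# getPlpData()): the regular expression used above to validate studyStartDate and
# studyEndDate accepts strings such as "20201339" that are not real calendar dates.
# Assuming a stricter check is wanted, one option is to additionally parse the value
# with as.Date(); an empty string keeps its meaning of "no date restriction".
isValidStudyDate <- function(x) {
  if (!is.character(x) || length(x) != 1 || is.na(x)) return(FALSE)
  if (x == "") return(TRUE)  # "" means no restriction, as in getPlpData()'s defaults
  # require exactly eight digits, then let as.Date() reject impossible months and days
  grepl("^[0-9]{8}$", x) && !is.na(as.Date(x, format = "%Y%m%d"))
}
# For example, isValidStudyDate("20200131") is TRUE, while "20201339" (day 39) and
# "20201301" (month 13) are rejected even though the regular expression accepts both.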


#' Save the cohort data to folder
#'
#' @description
#' \code{savePlpData} saves an object of type plpData to folder.
#'
#' @param plpData            An object of type \code{plpData} as generated using
#'                           \code{getPlpData}.
#' @param file               The name of the folder where the data will be written. The folder is
#'                           created if it does not yet exist.
#' @param envir              The environment in which to evaluate the arguments recorded in
#'                           \code{plpData$metaData$call} before saving.
#' @param overwrite          Whether to overwrite covariate data that has already been saved in
#'                           \code{file}.
#' @details
#' The data will be written to a set of files in the folder specified by the user.
#'
#' @examples
#' # todo
#'
#' @export
savePlpData <- function(plpData, file, envir = NULL, overwrite = FALSE) {
  if (missing(plpData))
    stop("Must specify plpData")
  if (missing(file))
    stop("Must specify file")
  if (!inherits(plpData, c("plpData", "plpData.libsvm")))
    stop("Data not of class plpData")
  # Andromeda::saveAndromeda() writes a single file, so check with file.exists(),
  # which also matches an existing "covariates" folder
  if (file.exists(file.path(file, "covariates")) && !overwrite) {
    stop('Covariates have already been saved in this folder; use overwrite = TRUE to replace them')
  }
  
  if(!dir.exists(file)){
    dir.create(file)
  }
  
  # save the actual values in the metaData
  # TODO - only do this if exists in parent or environ
  if(is.null(plpData$metaData$call$sampleSize)){  # fixed a bug when sampleSize is NULL
    plpData$metaData$call$sampleSize <- 'NULL'
  }
  for(i in 2:length(plpData$metaData$call)){
    if(!is.null(plpData$metaData$call[[i]]))
      plpData$metaData$call[[i]] <- eval(plpData$metaData$call[[i]], envir = envir)
  }
  
  #FeatureExtraction::saveCovariateData(covariateData = plpData$covariateData, file = file.path(file, "covariates"))
  Andromeda::saveAndromeda(plpData$covariateData, fileName = file.path(file, "covariates"), maintainConnection = TRUE)
  saveRDS(plpData$timeRef, file = file.path(file, "timeRef.rds"))
  saveRDS(plpData$cohorts, file = file.path(file, "cohorts.rds"))
  saveRDS(plpData$outcomes, file = file.path(file, "outcomes.rds"))
  saveRDS(plpData$metaData, file = file.path(file, "metaData.rds"))
}
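
# Sketch (hypothetical helper, not part of the original file): match.call() in
# getPlpData() records the arguments as unevaluated expressions, so a saved call could
# hold a symbol such as `myCohortId` instead of its value. savePlpData() therefore
# evaluates each argument before saving. The same idea as a stand-alone function; this
# version leaves any argument that evaluates to NULL as it was.
evaluateCallArguments <- function(call, envir = parent.frame()) {
  # element 1 is the function name; the arguments start at position 2
  for (i in seq_len(length(call))[-1]) {
    value <- eval(call[[i]], envir = envir)
    if (!is.null(value)) call[[i]] <- value
  }
  call
}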

#' Load the cohort data from a folder
#'
#' @description
#' \code{loadPlpData} loads an object of type plpData from a folder in the file
#' system.
#'
#' @param file       The name of the folder containing the data.
#' @param readOnly   If true, the data is opened read only. Currently not used by this function.
#'
#' @details
#' The data will be read from the set of files written to the folder by \code{savePlpData}.
#'
#' @return
#' An object of class plpData.
#'
#' @examples
#' # todo
#'
#' @export
loadPlpData <- function(file, readOnly = TRUE) {
  if (!file.exists(file))
    stop(paste("Cannot find folder", file))
  if (!file.info(file)$isdir)
    stop(paste("Not a folder", file))
  
  result <- list(covariateData = FeatureExtraction::loadCovariateData(file = file.path(file, "covariates")),
                 timeRef = readRDS(file.path(file, "timeRef.rds")),
                 cohorts = readRDS(file.path(file, "cohorts.rds")),
                 outcomes = readRDS(file.path(file, "outcomes.rds")),
                 metaData = readRDS(file.path(file, "metaData.rds")))
  class(result) <- "plpData"
  
  return(result)
}
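
# Sketch (hypothetical helper, not part of the original file): loadPlpData() reads the
# five entries that savePlpData() writes and stops at the first one that is missing.
# Checking for all of them up front, as below, would report every missing entry at once.
# The entry names are taken from savePlpData(); the helper name is illustrative.
checkPlpDataFolder <- function(file) {
  expected <- c("covariates", "timeRef.rds", "cohorts.rds", "outcomes.rds", "metaData.rds")
  # file.exists() is TRUE for both files and folders, so this also covers "covariates"
  missingEntries <- expected[!file.exists(file.path(file, expected))]
  if (length(missingEntries) > 0)
    stop(paste("Missing from", file, ":", paste(missingEntries, collapse = ", ")))
  invisible(TRUE)
}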

#' @export
print.plpData <- function(x, ...) {
  writeLines("plpData object")
  writeLines("")
  writeLines(paste("At risk concept ID:", attr(x$cohorts, "metaData")$cohortId))
  writeLines(paste("Outcome concept ID(s):", paste(attr(x$outcomes, "metaData")$outcomeIds, collapse = ",")))
}

#' @export
summary.plpData <- function(object, ...) {
  people <- length(unique(object$cohorts$subjectId))
  outcomeCounts <- data.frame(outcomeId = attr(object$outcomes, "metaData")$outcomeIds,
                              eventCount = 0,
                              personCount = 0)
  for (i in 1:nrow(outcomeCounts)) {
    outcomeCounts$eventCount[i] <- sum(object$outcomes$outcomeId == attr(object$outcomes, "metaData")$outcomeIds[i])
    outcomeCounts$personCount[i] <- length(unique(object$outcomes$rowId[object$outcomes$outcomeId == attr(object$outcomes, "metaData")$outcomeIds[i]]))
  }
  
  covDetails <- FeatureExtraction::summary(object$covariateData)
  result <- list(metaData = append(append(object$metaData, attr(object$cohorts, "metaData")), attr(object$outcomes, "metaData")),
                 people = people,
                 outcomeCounts = outcomeCounts,
                 covariateCount = covDetails$covariateCount,
                 covariateValueCount = covDetails$covariateValueCount)
  class(result) <- "summary.plpData"
  return(result)
}
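
# Sketch (hypothetical helper, not part of the original file): the per-outcome counts
# computed in summary.plpData(), written with vapply(). Iterating over the outcome ids
# directly avoids the 1:nrow() idiom, which yields c(1, 0) for a zero-row table, and a
# NULL outcomes table (outcomeIds = -999 in getPlpData()) simply gives zero counts.
countOutcomes <- function(outcomes, outcomeIds) {
  data.frame(
    outcomeId = outcomeIds,
    # number of outcome records per outcome id
    eventCount = vapply(outcomeIds, function(id) sum(outcomes$outcomeId == id), integer(1)),
    # number of distinct rowIds (persons) with that outcome
    personCount = vapply(outcomeIds, function(id) {
      length(unique(outcomes$rowId[outcomes$outcomeId == id]))
    }, integer(1))
  )
}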

#' @export
print.summary.plpData <- function(x, ...) {
  writeLines("plpData object summary")
  writeLines("")
  writeLines(paste("At risk cohort concept ID:", x$metaData$cohortId))
  writeLines(paste("Outcome concept ID(s):", paste(x$metaData$outcomeIds, collapse = ",")))
  writeLines("")
  writeLines(paste("People:", paste(x$people)))
  writeLines("")
  writeLines("Outcome counts:")
  outcomeCounts <- x$outcomeCounts
  rownames(outcomeCounts) <- outcomeCounts$outcomeId
  outcomeCounts$outcomeId <- NULL
  colnames(outcomeCounts) <- c("Event count", "Person count")
  printCoefmat(outcomeCounts)
  writeLines("")
  writeLines("Covariates:")
  writeLines(paste("Number of covariates:", x$covariateCount))
  writeLines(paste("Number of non-zero covariate values:", x$covariateValueCount))
}


#' Saves the plp model
#'
#' @details
#' Saves the plp model to a user-specified folder
#'
#' @param plpModel                   A trained classifier returned by running \code{runPlp()$model}
#' @param dirPath                  A location to save the model to
#'
#' @export
savePlpModel <- function(plpModel, dirPath){
  if (missing(plpModel))
    stop("Must specify plpModel")
  if (missing(dirPath))
    stop("Must specify directory path")
  if (!inherits(plpModel, "plpModel"))
    stop("Not a plpModel")
  
  if(!dir.exists(dirPath)) dir.create(dirPath)
  
  
  # If model is saved on hard drive move it...
  #============================================================
  moveFile <- moveHdModel(plpModel, dirPath )
  if(!moveFile){
    ParallelLogger::logError('Moving model files error')
  }
  #============================================================
  
  
  # if deep (keras) then save hdfs
  if (attr(plpModel, 'type') %in% c('deep', 'deepMulti', 'deepEnsemble')) {
    
    if (attr(plpModel, 'type') == 'deepEnsemble') {
      # note: any error while saving the ensemble is silently ignored here
      tryCatch(
        {#saveRDS(plpModel, file = file.path(dirPath,  "deepEnsemble_model.rds"))
          for (i in seq(plpModel$modelSettings$modelParameters$numberOfEnsembleNetwork)) {
            model <- keras::serialize_model(plpModel$model[[i]], include_optimizer = TRUE)
            keras::save_model_hdf5(model, filepath = file.path(dirPath, "keras_model", i))
          }
        }, error = function(e) NULL
      )
    }
    if (attr(plpModel, 'type') == 'deep') {
      keras::save_model_hdf5(plpModel$model, filepath = file.path(dirPath, "keras_model"))
    }
    if (attr(plpModel, 'type') == 'deepMulti') {
      saveRDS(attr(plpModel, 'inputs'), file = file.path(dirPath, "inputs_attr.rds"))
    }
  } else if (attr(plpModel, 'type') == "xgboost") {
    # fixing xgboost save/load issue
    xgboost::xgb.save(model = plpModel$model, fname = file.path(dirPath, "model"))
  } else {
    saveRDS(plpModel$model, file = file.path(dirPath, "model.rds"))
  }
  saveRDS(plpModel$predict, file = file.path(dirPath, "transform.rds"))
  saveRDS(plpModel$index, file = file.path(dirPath, "index.rds"))
  saveRDS(plpModel$trainCVAuc, file = file.path(dirPath, "trainCVAuc.rds"))
  saveRDS(plpModel$hyperParamSearch, file = file.path(dirPath, "hyperParamSearch.rds"))
  saveRDS(plpModel$modelSettings, file = file.path(dirPath,  "modelSettings.rds"))
  saveRDS(plpModel$metaData, file = file.path(dirPath, "metaData.rds"))
  saveRDS(plpModel$populationSettings, file = file.path(dirPath, "populationSettings.rds"))
  saveRDS(plpModel$trainingTime, file = file.path(dirPath,  "trainingTime.rds"))
  saveRDS(plpModel$varImp, file = file.path(dirPath,  "varImp.rds"))
  saveRDS(plpModel$dense, file = file.path(dirPath,  "dense.rds"))
  saveRDS(plpModel$cohortId, file = file.path(dirPath,  "cohortId.rds"))
  saveRDS(plpModel$outcomeId, file = file.path(dirPath,  "outcomeId.rds"))
  saveRDS(plpModel$analysisId, file = file.path(dirPath,  "analysisId.rds"))
  #if(!is.null(plpModel$covariateMap))
  saveRDS(plpModel$covariateMap, file = file.path(dirPath,  "covariateMap.rds"))
  
  attributes <- list(type=attr(plpModel, 'type'), predictionType=attr(plpModel, 'predictionType') )
  saveRDS(attributes, file = file.path(dirPath,  "attributes.rds"))
  
  
}
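
# Sketch (hypothetical helpers, not part of the original file): savePlpModel() stores
# each model component in its own .rds file, and loadPlpModel() reads optional
# components through tryCatch() so that a missing file becomes NULL. The same pattern
# written generically, assuming the component names are used directly as file names
# (savePlpModel() itself uses some different names, e.g. transform.rds for predict).
saveComponents <- function(x, dirPath, components = names(x)) {
  if (!dir.exists(dirPath)) dir.create(dirPath, recursive = TRUE)
  for (name in components) {
    saveRDS(x[[name]], file = file.path(dirPath, paste0(name, ".rds")))
  }
  invisible(file.path(dirPath, paste0(components, ".rds")))
}

loadComponents <- function(dirPath, components) {
  result <- lapply(components, function(name) {
    path <- file.path(dirPath, paste0(name, ".rds"))
    # a missing optional component becomes NULL instead of an error
    if (file.exists(path)) readRDS(path) else NULL
  })
  names(result) <- components
  result
}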

moveHdModel <- function(plpModel, dirPath ){
  #==================================================================
  # if python then move pickle
  #==================================================================
  if(attr(plpModel, 'type') %in% c('pythonOld','pythonReticulate', 'pythonAuto') ){
    if(!dir.exists(file.path(dirPath,'python_model')))
      dir.create(file.path(dirPath,'python_model'))
    for(file in dir(plpModel$model)){   #DOES THIS CORRECTLY TRANSFER AUTOENCODER BITS?
      file.copy(file.path(plpModel$model,file), 
                file.path(dirPath,'python_model'), overwrite=TRUE,  recursive = FALSE,
                copy.mode = TRUE, copy.date = FALSE)
    }
  }
  
  #==================================================================
  # if sagemaker then move pickle
  #==================================================================
  if(attr(plpModel, 'type') =='sagemaker'){
    if(!dir.exists(file.path(dirPath,'sagemaker_model')))
      dir.create(file.path(dirPath,'sagemaker_model'))
    for(file in dir(plpModel$model$loc)){
      file.copy(file.path(plpModel$model$loc,file), 
                file.path(dirPath,'sagemaker_model'), overwrite=TRUE,  recursive = FALSE,
                copy.mode = TRUE, copy.date = FALSE)
    }
  }
  
  #==================================================================
  # if knn then move model
  #==================================================================
  if(attr(plpModel, 'type') =='knn'){
    if(!dir.exists(file.path(dirPath,'knn_model')))
      dir.create(file.path(dirPath,'knn_model'))
    for(file in dir(plpModel$model)){
      file.copy(file.path(plpModel$model,file), 
                file.path(dirPath,'knn_model'), overwrite=TRUE,  recursive = FALSE,
                copy.mode = TRUE, copy.date = FALSE)
    }
  }
  
  return(TRUE)
}
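
# Sketch (hypothetical helper, not part of the original file): moveHdModel() copies the
# files of a python, sagemaker or knn model one at a time and always returns TRUE.
# Because file.copy() is vectorised, the copy can be done in a single call whose result
# reports whether every file was copied. As in moveHdModel(), recursive = FALSE means
# sub-folders are not copied.
copyModelFolder <- function(fromDir, toDir) {
  if (!dir.exists(toDir)) dir.create(toDir, recursive = TRUE)
  files <- list.files(fromDir, full.names = TRUE)
  copied <- file.copy(files, toDir, overwrite = TRUE, recursive = FALSE,
                      copy.mode = TRUE, copy.date = FALSE)
  all(copied)
}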

#' Loads the plp model
#'
#' @details
#' Loads a plp model that was saved using \code{savePlpModel()}
#'
#' @param dirPath                  The location of the model
#'
#' @export
loadPlpModel <- function(dirPath) {
  if (!file.exists(dirPath))
    stop(paste("Cannot find folder", dirPath))
  if (!file.info(dirPath)$isdir)
    stop(paste("Not a folder", dirPath))
  
  hyperParamSearch <- tryCatch(readRDS(file.path(dirPath, "hyperParamSearch.rds")),
                               error=function(e) NULL)
  # add in these as they got dropped
  outcomeId <- tryCatch(readRDS(file.path(dirPath, "outcomeId.rds")),
                        error=function(e) NULL)
  cohortId <- tryCatch(readRDS(file.path(dirPath, "cohortId.rds")),
                       error=function(e) NULL)  
  dense <- tryCatch(readRDS(file.path(dirPath, "dense.rds")),
                    error=function(e) NULL)  
  covariateMap <- tryCatch(readRDS(file.path(dirPath, "covariateMap.rds")),
                           error=function(e) NULL) 
  analysisId <- tryCatch(readRDS(file.path(dirPath, "analysisId.rds")),
                         error=function(e) NULL) 
  
  if (file.exists(file.path(dirPath, "keras_model"))) {
    model <- keras::load_model_hdf5(file.path(dirPath, "keras_model"))
  } else if (readRDS(file.path(dirPath, "attributes.rds"))$type == "xgboost") {
    # fixing xgboost save/load issue
    model <- xgboost::xgb.load(file.path(dirPath, "model"))
  } else {
    model <- readRDS(file.path(dirPath, "model.rds"))
  }
  
  result <- list(model = model,
                 modelSettings = readRDS(file.path(dirPath, "modelSettings.rds")),
                 hyperParamSearch = hyperParamSearch,
                 trainCVAuc = readRDS(file.path(dirPath, "trainCVAuc.rds")),
                 metaData = readRDS(file.path(dirPath, "metaData.rds")),
                 populationSettings= readRDS(file.path(dirPath, "populationSettings.rds")),
                 outcomeId = outcomeId,
                 cohortId = cohortId,
                 varImp = readRDS(file.path(dirPath, "varImp.rds")),
                 trainingTime = readRDS(file.path(dirPath, "trainingTime.rds")),
                 covariateMap =covariateMap,
                 predict = readRDS(file.path(dirPath, "transform.rds")),
                 index = readRDS(file.path(dirPath, "index.rds")),
                 dense = dense,
                 analysisId = analysisId)
  
  attributes <- readRDS(file.path(dirPath, "attributes.rds"))
  attr(result, 'type') <- attributes$type
  attr(result, 'predictionType') <- attributes$predictionType
  class(result) <- "plpModel"
  
  # update the model location to the load dirPath
  result <- updateModelLocation(result, dirPath)
  
  # make this backwards compatible for ffdf:
  result$predict <- createTransform(result)
  
  return(result)
}

updateModelLocation  <- function(plpModel, dirPath){
  type <- attr(plpModel, 'type')
  # if python update the location
  if( type %in% c('pythonOld','pythonReticulate', 'pythonAuto')){
    plpModel$model <- file.path(dirPath,'python_model')
    plpModel$predict <- createTransform(plpModel)
  }
  if( type =='sagemaker'){
    plpModel$model$loc <- file.path(dirPath,'sagemaker_model')
    plpModel$predict <- createTransform(plpModel)
  }
  # if knn update the location - TODO !!!!!!!!!!!!!!
  if( type =='knn'){
    plpModel$model <- file.path(dirPath,'knn_model')
    plpModel$predict <- createTransform(plpModel)
  }
  if( type =='deep' ){
    plpModel$predict <- createTransform(plpModel)
  }
  if( type =='deepEnsemble' ){
    plpModel$predict <- createTransform(plpModel)
  }
  if( type =='deepMulti'){
    attr(plpModel, 'inputs') <- tryCatch(readRDS(file.path(dirPath, "inputs_attr.rds")),
                                         error=function(e) NULL) 
    plpModel$predict <- createTransform(plpModel)
    
  }
  
  return(plpModel) 
}
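
# Sketch (hypothetical helper, not part of the original file): moveHdModel() and
# updateModelLocation() both encode which model types keep their files in a sub-folder
# of dirPath and what that sub-folder is called. A single lookup such as this one could
# keep the two functions in step; NA means the model type uses no such sub-folder.
modelSubfolder <- function(type) {
  if (type %in% c("pythonOld", "pythonReticulate", "pythonAuto")) return("python_model")
  switch(type,
         sagemaker = "sagemaker_model",
         knn = "knn_model",
         NA_character_)
}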


#' Saves the prediction dataframe to RDS
#'
#' @details
#' Saves the prediction data frame returned by predict.R to an RDS file and returns the file location where the prediction is saved
#'
#' @param prediction                  The prediction data.frame
#' @param dirPath                     The directory to save the prediction RDS
#' @param fileName                    The name of the RDS file that will be saved in dirPath
#' 
#' @export
savePrediction <- function(prediction, dirPath, fileName='prediction.rds'){
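  if (missing(prediction))
    stop("Must specify prediction")
  if (missing(dirPath))
    stop("Must specify dirPath")
  if (!dir.exists(dirPath)) dir.create(dirPath, recursive = TRUE)
  saveRDS(prediction, file = file.path(dirPath, fileName))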
  
  return(file.path(dirPath,fileName))
}

#' Loads the prediction data frame from an RDS file
#'
#' @details
#' Loads the prediction RDS file saved by \code{savePrediction()} and returns the data frame
#'
#' @param fileLocation                     The location with the saved prediction
#' 
#' @export
loadPrediction <- function(fileLocation){
  if (!file.exists(fileLocation))
    stop(paste("Cannot find file", fileLocation))
  prediction <- readRDS(file=fileLocation)
  return(prediction)
}

#' Saves the result from runPlp into the location directory
#'
#' @details
#' Saves the result from runPlp into the location directory
#'
#' @param result                      The result of running runPlp()
#' @param dirPath                     The directory in which to save the results
#' 
#' @export
savePlpResult <- function(result, dirPath){
  if (missing(result))
    stop("Must specify runPlp output")
  if (missing(dirPath))
    stop("Must specify directory location")
  #if (class(plpModel) != "plpModel")
  #  stop("Not a plpModel")
  
  if(!dir.exists(dirPath)) dir.create(dirPath, recursive = T)
  
  savePlpModel(result$model, dirPath=file.path(dirPath,'model') )
  saveRDS(result$analysisRef, file = file.path(dirPath, "analysisRef.rds"))
  saveRDS(result$inputSetting, file = file.path(dirPath, "inputSetting.rds"))
  saveRDS(result$executionSummary, file = file.path(dirPath, "executionSummary.rds"))
  saveRDS(result$prediction, file = file.path(dirPath, "prediction.rds"))
  saveRDS(result$performanceEvaluation, file = file.path(dirPath, "performanceEvaluation.rds"))
  #saveRDS(result$performanceEvaluationTrain, file = file.path(dirPath, "performanceEvaluationTrain.rds"))
  saveRDS(result$covariateSummary, file = file.path(dirPath, "covariateSummary.rds"))
  
  
}

#' Loads the result of runPlp
#'
#' @details
#' Loads a result saved with \code{savePlpResult()} and returns an object of class \code{runPlp}
#'
#' @param dirPath                     The directory where the result was saved
#' 
#' @export
loadPlpResult <- function(dirPath){
  if (!file.exists(dirPath))
    stop(paste("Cannot find folder", dirPath))
  if (!file.info(dirPath)$isdir)
    stop(paste("Not a folder", dirPath))
  
  
  result <- list(model = loadPlpModel(file.path(dirPath, "model")),
                 analysisRef = readRDS(file.path(dirPath, "analysisRef.rds")),
                 inputSetting = readRDS(file.path(dirPath, "inputSetting.rds")),
                 executionSummary = readRDS(file.path(dirPath, "executionSummary.rds")),
                 prediction = readRDS(file.path(dirPath, "prediction.rds")),
                 performanceEvaluation = readRDS(file.path(dirPath, "performanceEvaluation.rds")),
                 #performanceEvaluationTrain= readRDS(file.path(dirPath, "performanceEvaluationTrain.rds")),
                 covariateSummary = readRDS(file.path(dirPath, "covariateSummary.rds"))
  )
  class(result) <- "runPlp"
  
  return(result)
  
}
ted9219/CoDImputationHeart documentation built on Sept. 15, 2020, 11:30 a.m.