# R/FamiliarModel.R
#
# Defines functions .get_available_risklike_prediction_types

#' @include FamiliarS4Generics.R
#' @include FamiliarS4Classes.R
NULL



# .train (familiarModel, dataObject) -------------------------------------------
# Trains a familiarModel on a dataObject.
#
# Arguments, as used in the body below:
# * object: the familiarModel to train.
# * data: the dataObject used for training.
# * get_additional_info: if TRUE, recalibration models, calibration info,
#   risk stratification thresholds and data column info are added after
#   training.
# * is_pre_processed: forwarded to process_input_data.
# * trim_model: if TRUE, the trained object is passed through trim_model().
# * timeout: forwarded to trim_model().
# * approximate: forwarded to ..train.
# * ...: forwarded to ..train and ..train_naive.
#
# Returns the (possibly updated) familiarModel object.
setMethod(
  ".train",
  signature(
    object = "familiarModel",
    data = "dataObject"),
  function(
    object,
    data,
    get_additional_info = FALSE,
    is_pre_processed = FALSE,
    trim_model = TRUE,
    timeout = 60000,
    approximate = FALSE,
    ...) {
    # Train method for model training

    # Check if the class of object is a subclass of familiarModel. If it is
    # not, the object is promoted using promote_learner.
    if (!is_subclass(class(object)[1], "familiarModel")) object <- promote_learner(object)

    # Process data, if required. Processing is requested up to the
    # "clustering" step.
    data <- process_input_data(
      object = object,
      data = data,
      is_pre_processed = is_pre_processed,
      stop_at = "clustering")

    # Work only with data that has known outcomes when training.
    data <- filter_missing_outcome(data = data)

    # Set the training flags. can_train controls regular training through
    # ..train; can_train_naive controls the fall-back through ..train_naive.
    # Both start as TRUE and are only ever switched off by the checks below.
    can_train <- can_train_naive <- TRUE

    # Check if there are any data entries. The familiar model cannot be trained
    # otherwise. We do allow for no features being present.
    if (is_empty(x = data, allow_no_features = TRUE)) {
      can_train <- can_train_naive <- FALSE
      object <- ..update_errors(
        object = object,
        ..error_message_no_training_data_available())
    }

    # Check the number of features in data; if it has no features, a standard
    # familiar model can not be trained. However, it might be possible to train
    # a naive model, so can_train_naive is left untouched here.
    if (!has_feature_data(x = data)) {
      can_train <- FALSE
      object <- ..update_errors(
        object = object,
        ..error_message_no_features_selected_for_training())
    }

    # Check if the hyperparameters are plausible. Without optimised
    # hyperparameters neither a regular nor a naive model is trained.
    if (!has_optimised_hyperparameters(object = object)) {
      can_train <- can_train_naive <- FALSE
      object <- ..update_errors(
        object = object,
        ..error_message_no_optimised_hyperparameters_available())
    }

    # Check if a naive model should be trained. A naive model is only
    # considered when requires_naive_model reports that one is required.
    if (!requires_naive_model(object)) {
      can_train_naive <- FALSE
    }

    # Add outcome distribution data. This happens regardless of the training
    # flags.
    object@outcome_info <- .compute_outcome_distribution_data(
      object = object@outcome_info,
      data = data)

    # Train a new model based on data. If a normal model cannot be trained,
    # attempt to train a naive model. If neither flag is set, the object is
    # left as is.
    if (can_train) {
      object <- ..train(
        object = object,
        data = data,
        approximate = approximate,
        ...)
      
    } else if (can_train_naive) {
      object <- ..train_naive(
        object = object,
        data = data,
        ...)
    }

    # Extract information required for assessing model performance, calibration
    # (e.g. baseline survival) etc.
    if (get_additional_info) {
      # Remove duplicate samples from the data prior to obtaining additional
      # data.
      data <- aggregate_data(data = data)

      # Create calibration models and add to the object. Not all models require
      # recalibration.
      if (can_train) {
        object <- ..set_recalibration_model(
          object = object,
          data = data)
      }

      # Extract data required for assessing calibration. Not all outcome types
      # require calibration info. Currently calibration information is only
      # retrieved for survival outcomes, in the form of baseline survival
      # curves.
      if (can_train || can_train_naive) {
        object <- ..set_calibration_info(
          object = object,
          data = data)
      }

      # Set stratification thresholds. This is currently only done for
      # survival outcomes.
      if (can_train) {
        object <- ..set_risk_stratification_thresholds(
          object = object,
          data = data)
      }

      # Add column data
      object <- add_data_column_info(
        object = object,
        data = data)
    }

    # Trim the model on request. Note that the logical argument trim_model has
    # the same name as the trim_model generic: when resolving a call, R skips
    # bindings that are not functions, so the call below reaches the generic.
    if (trim_model) object <- trim_model(object = object, timeout = timeout)

    # Empty slots if a model can not be trained. This also applies when only a
    # naive model was trained, because can_train is FALSE in that case.
    if (!can_train) {
      object@required_features <- NULL
      object@model_features <- NULL
      object@novelty_features <- NULL
    }

    return(object)
  }
)


# .train (familiarModel, NULL)--------------------------------------------------
setMethod(
  ".train",
  signature(
    object = "familiarModel",
    data = "NULL"),
  function(object, data, ...) {
    # Without data there is nothing to train on: hand the object back
    # untouched.
    object
  }
)


# .train_novelty_detector ----------------------------------------------------
setMethod(
  ".train_novelty_detector",
  signature(
    object = "familiarModel",
    data = "dataObject"),
  function(
    object,
    data,
    detector,
    get_additional_info = FALSE,
    is_pre_processed = FALSE,
    trim_model = TRUE,
    ...) {
    # Creates a familiarNoveltyDetector from the information stored in the
    # familiarModel, trains it, and attaches it to the novelty_detector slot
    # of the model. The updated familiarModel is returned.

    # Promote the model object if it is not yet a subclass of familiarModel.
    if (!is_subclass(class(object)[1], "familiarModel")) {
      object <- promote_learner(object)
    }

    # Set up the detector, and promote it to the correct type of detector. The
    # novelty features of the model serve as the model features of the
    # detector.
    novelty_detector <- promote_detector(
      object = methods::new(
        "familiarNoveltyDetector",
        learner = detector,
        feature_info = object@feature_info,
        required_features = object@required_features,
        model_features = object@novelty_features,
        run_table = object@run_table,
        project_id = object@project_id))

    # Process the data for use by the detector, if required.
    data <- process_input_data(
      object = novelty_detector,
      data = data,
      is_pre_processed = is_pre_processed,
      stop_at = "clustering",
      force_check = TRUE)

    # Hyperparameters are only optimised when they were not set previously.
    if (!has_optimised_hyperparameters(object = novelty_detector)) {
      novelty_detector <- optimise_hyperparameters(
        object = novelty_detector,
        data = data,
        ...)
    }

    # Train the detector and attach it to the model.
    object@novelty_detector <- .train(
      object = novelty_detector,
      data = data)

    return(object)
  }
)


# show (model) -----------------------------------------------------------------
setMethod(
  "show",
  signature(object = "familiarModel"),
  function(object) {
    # Prints a description of the familiarModel object to the console. For
    # models that were not trained, only a short notice and any stored
    # warnings and errors are shown. For trained models, the model details,
    # outcome, hyperparameters, variable importance method, features, novelty
    # detector and any stored warnings and errors are shown.

    # Make sure the model object is updated.
    object <- update_object(object = object)

    # Local helper that prints stored conditions (warnings or errors), if any.
    # Every block is terminated by a newline so that subsequent console output
    # starts on a fresh line.
    show_conditions <- function(conditions, singular, plural, context) {
      if (length(conditions) == 0) return(invisible(NULL))

      condition_messages <- condition_summary(conditions)
      cat(paste0(
        "\nThe following ",
        if (length(condition_messages) == 1) singular else plural,
        context,
        paste0(condition_messages, collapse = "\n"),
        "\n"))

      return(invisible(NULL))
    }

    if (!model_is_trained(object)) {
      cat(paste0(
        "A ", object@learner, " model (class: ", class(object)[1],
        ") that was not successfully trained (",
        .familiar_version_string(object), ").\n"))

      show_conditions(
        conditions = object@messages$warning,
        singular = "warning was",
        plural = "warnings were",
        context = " generated while trying to train the model:\n")

      show_conditions(
        conditions = object@messages$error,
        singular = "error was",
        plural = "errors were",
        context = " encountered while trying to train the model:\n")

      return(invisible(NULL))
    }

    # Describe the learner and the version of familiar.
    message_str <- paste0(
      "A ", object@learner, " model (class: ", class(object)[1],
      "; ", .familiar_version_string(object), ")")

    # Describe the package(s), if any
    if (!is.null(object@package)) {
      message_str <- c(
        message_str,
        " trained using ",
        paste_s(mapply(
          ..message_package_version,
          x = object@package,
          version = object@package_version)),
        if (length(object@package) > 1) " packages" else " package")
    }

    # Complete message and write.
    message_str <- paste0(c(message_str, ".\n"), collapse = "")
    cat(message_str)

    cat("\n--------------- Model details ---------------\n")

    # Model details. Trimmed models carry stored show output; otherwise the
    # underlying model is shown directly.
    if (object@is_trimmed) {
      cat(object@trimmed_function$show, sep = "\n")
    } else {
      show(object@model)
    }

    cat("\n---------------------------------------------\n")

    # Outcome details
    cat("\nThe following outcome was modelled:\n")
    show(object@outcome_info)

    # Details concerning hyperparameters.
    cat("\nThe model was trained using the following hyperparameters:\n")
    invisible(lapply(
      names(object@hyperparameters),
      function(x, object) {
        cat(paste0("  ", x, ": ", object@hyperparameters[[x]], "\n"))
      },
      object = object))

    # Details concerning variable importance.
    cat(paste0(
      "\nVariable importance was determined using the ",
      object@fs_method, " variable importance method.\n"))

    # Details concerning model features:
    cat("\nThe following features were used in the model:\n")
    invisible(lapply(
      object@model_features,
      function(x, object) show(object@feature_info[[x]]),
      object = object))

    # Details concerning novelty features:
    if (!model_is_trained(object@novelty_detector)) {
      cat("\nNo novelty detector was trained.\n")

    } else if (setequal(object@model_features, object@novelty_features)) {
      cat("\nA novelty detector was trained using the model features.\n")

    } else {
      cat("\nA novelty detector was trained using the model features above, and additionally:\n\n")

      # Identify novelty features that were set in addition to model
      # features.
      novelty_features <- setdiff(object@novelty_features, object@model_features)

      invisible(lapply(
        novelty_features,
        function(x, object) show(object@feature_info[[x]]),
        object = object))
    }

    if (length(object@messages$warning) > 0 || length(object@messages$error) > 0) {
      cat("\n------------ Warnings and errors ------------\n")

      show_conditions(
        conditions = object@messages$warning,
        singular = "warning was",
        plural = "warnings were",
        context = " generated while training the model:\n")

      show_conditions(
        conditions = object@messages$error,
        singular = "error was",
        plural = "errors were",
        context = " encountered while training the model:\n")
    }

    # Check package version.
    check_package_version(object)
  }
)


#' Model summaries
#'
#' `summary` produces model summaries.
#'
#' @param object a familiarModel object
#' @param ... additional arguments passed to `summary` methods for the underlying
#'   model, when available.
#'
#' @details This method extends the `summary` S3 method. For some models
#'   `summary` requires information that is trimmed from the model. In this case
#'   a copy of summary data is stored with the model, and returned.
#'
#' @return Depends on underlying model. See the documentation for the particular
#'   models.
#'
#' @exportMethod summary
#' @rdname summary-methods
#' @md
setGeneric("summary")

# summary-----------------------------------------------------------------------

#' @rdname summary-methods
setMethod(
  "summary",
  signature(object = "familiarModel"),
  function(object, ...) {
    # Bring the model object up to date before inspecting it.
    object <- update_object(object = object)

    # Untrained models have nothing to summarise.
    if (!model_is_trained(object)) {
      message("The model was not trained. A summary is not available.")
      return(invisible(NULL))
    }

    # Trimmed models may carry a stored copy of the summary; prefer that copy
    # when it is present.
    stored_summary <- if (object@is_trimmed) object@trimmed_function$summary else NULL
    if (!is.null(stored_summary)) return(stored_summary)

    # Otherwise ask the underlying model for its summary, capturing any error
    # as a value instead of letting it propagate.
    model_summary <- tryCatch(
      summary(object@model, ...),
      error = function(err) err)

    # Neither a captured error nor a default summary counts as a model
    # summary: notify the user and return an invisible NULL.
    if (inherits(model_summary, c("error", "summaryDefault"))) {
      message("No summary is available for this model.")
      return(invisible(NULL))
    }

    return(model_summary)
  }
)



#' Extract model coefficients
#'
#' @param object a familiarModel object
#' @param ... additional arguments passed to `coef` methods for the underlying
#'   model, when available.
#'
#' @details This method extends the `coef` S3 method. For some models `coef`
#'   requires information that is trimmed from the model. In this case a copy of
#'   the model coefficient is stored with the model, and returned.
#'
#' @return Coefficients extracted from the model in the familiarModel object, if
#'   any.
#' @export
#' @rdname coef-methods
#' @md
setGeneric("coef")

# coef -------------------------------------------------------------------------

#' @rdname coef-methods
setMethod(
  "coef",
  signature(object = "familiarModel"),
  function(object, ...) {
    # Bring the model object up to date before inspecting it.
    object <- update_object(object = object)

    # Models that were not trained have no coefficients.
    if (!model_is_trained(object)) return(NULL)

    # Trimmed models may carry a stored copy of the coefficients; prefer that
    # copy when it is present.
    stored_coefficients <- if (object@is_trimmed) object@trimmed_function$coef else NULL
    if (!is.null(stored_coefficients)) return(stored_coefficients)

    # Otherwise ask the underlying model, capturing any error as a value.
    model_coefficients <- tryCatch(
      coef(object@model, ...),
      error = function(err) err)

    # An error raised by coef yields NULL.
    if (inherits(model_coefficients, "error")) {
      return(NULL)
    }

    return(model_coefficients)
  }
)



#' Calculate variance-covariance matrix for a model
#'
#' @param object a familiarModel object
#' @param ... additional arguments passed to `vcov` methods for the underlying
#'   model, when available.
#'
#' @details This method extends the `vcov` S3 method. For some models `vcov`
#'   requires information that is trimmed from the model. In this case a copy of
#'   the variance-covariance matrix is stored with the model, and returned.
#'
#' @return Variance-covariance matrix of the model in the familiarModel object,
#'   if any.
#' @export
#' @rdname vcov-methods
#' @md
setGeneric("vcov")

# vcov -------------------------------------------------------------------------

#' @rdname vcov-methods
setMethod(
  "vcov",
  signature(object = "familiarModel"),
  function(object, ...) {
    # Bring the model object up to date before inspecting it.
    object <- update_object(object = object)

    # Models that were not trained have no variance-covariance matrix.
    if (!model_is_trained(object)) return(NULL)

    # Trimmed models may carry a stored copy of the matrix; prefer that copy
    # when it is present.
    stored_matrix <- if (object@is_trimmed) object@trimmed_function$vcov else NULL
    if (!is.null(stored_matrix)) return(stored_matrix)

    # Otherwise ask the underlying model, capturing any error as a value.
    model_matrix <- tryCatch(
      vcov(object@model, ...),
      error = function(err) err)

    # An error raised by vcov yields NULL.
    if (inherits(model_matrix, "error")) {
      return(NULL)
    }

    return(model_matrix)
  }
)



# require_package (model) ------------------------------------------------------
setMethod(
  "require_package",
  signature(x = "familiarModel"),
  function(x, purpose = NULL, message_type = "error", ...) {
    # Checks the packages listed in the package slot of the model through
    # .require_package. Returns invisibly.

    # Nothing to check when the model does not list any package.
    if (is_empty(x@package)) return(invisible(TRUE))

    # Short-hand purposes that are expanded to full descriptions. Any other
    # purpose is passed on unchanged.
    standard_purposes <- c(
      "train" = "to train a model",
      "vimp" = "to determine variable importance",
      "predict" = "to create model predictions",
      "show" = "to capture output",
      "distribution" = "to set the model distribution")

    if (!is.null(purpose)) {
      if (purpose %in% names(standard_purposes)) {
        purpose <- standard_purposes[[purpose]]
      }
    }

    package_check <- .require_package(
      x = x@package,
      purpose = purpose,
      message_type = message_type)

    return(invisible(package_check))
  }
)



# set_package_version (model) --------------------------------------------------
setMethod(
  "set_package_version",
  signature(object = "familiarModel"),
  function(object) {
    # Stores the installed version of every package listed in the package
    # slot in the package_version slot, and returns the updated object.

    # Do not add package versions if there are no packages.
    if (is_empty(object@package)) return(object)

    # Obtain package versions. vapply guarantees a character vector with one
    # version string per package, whereas the return type of sapply depends on
    # its input.
    object@package_version <- vapply(
      object@package,
      function(x) as.character(utils::packageVersion(x)),
      FUN.VALUE = character(1L))

    return(object)
  }
)



# check_package_version (model) ------------------------------------------------
setMethod(
  "check_package_version",
  signature(object = "familiarModel"),
  function(object) {
    # Hands the packages and package versions stored with the model to
    # .check_package_version.
    package_names <- object@package
    package_versions <- object@package_version

    .check_package_version(
      name = package_names,
      version = package_versions,
      when = "at model creation")
  }
)



# save (model) -----------------------------------------------------------------
setMethod(
  "save",
  signature(
    list = "familiarModel",
    file = "character"),
  function(list, file) {
    # The familiarModel object arrives as list, and file is handed to .save
    # as the directory path.
    .save(
      object = list,
      dir_path = file)
  }
)



# add_model_name (ANY, familiarModel)-------------------------------------------
setMethod(
  "add_model_name",
  signature(
    data = "ANY",
    object = "familiarModel"),
  function(data, object) {
    # Fall-back method: only empty data is expected here, for which NULL is
    # returned. Non-empty data should have been handled by a more specific
    # method.
    if (!is_empty(data)) {
      return(..error_reached_unreachable_code(
        "add_model_name,any,familiarModel: no method for non-empty data."))
    }

    return(NULL)
  }
)



# add_model_name (familiarDataElement, familiarModel)---------------------------
setMethod(
  "add_model_name",
  signature(
    data = "familiarDataElement",
    object = "familiarModel"),
  function(data, object) {
    # Stores the model name as the "model_name" identifier of the data
    # element, and returns the data element.

    # Use the name stored with the model. If no name was set, generate an
    # abbreviated name instead.
    model_name <- if (length(object@name) == 0) {
      get_object_name(object = object, abbreviated = TRUE)
    } else {
      object@name
    }

    # Create the list of identifiers if it does not exist yet; otherwise add
    # to (or overwrite in) the existing list.
    if (is.null(data@identifiers)) {
      data@identifiers <- list("model_name" = model_name)
    } else {
      data@identifiers[["model_name"]] <- model_name
    }

    return(data)
  }
)



# add_model_name (familiarDataElement, character) ------------------------------
setMethod(
  "add_model_name",
  signature(
    data = "familiarDataElement",
    object = "character"),
  function(data, object) {
    # Load the object referred to by the character value, then dispatch again
    # using the loaded object.
    loaded_object <- load_familiar_object(object)

    call_arguments <- list(
      "data" = data,
      "object" = loaded_object)

    return(do.call(add_model_name, args = call_arguments))
  }
)



# set_object_name (familiarModel) ----------------------------------------------

#' @title Set the name of a `familiarModel` object.
#'
#' @description Set the `name` slot using the object name.
#'
#' @param x A `familiarModel` object.
#'
#' @return A `familiarModel` object with a generated or a provided name.
#' @md
#' @keywords internal
setMethod(
  "set_object_name",
  signature(x = "familiarModel"),
  function(x, new = NULL) {
    # Three cases are distinguished: no name was provided for an object with
    # project_id 0, no name was provided for any other object, and a name was
    # provided.
    if (x@project_id == 0 && is.null(new)) {
      # A project_id of 0 means that the object was auto-generated (i.e.
      # through object conversion). Combine a time stamp with randomly
      # generated characters, so that a name collision is practically
      # impossible.
      time_stamp <- as.character(as.numeric(format(Sys.time(), "%H%M%S")))
      x@name <- paste0(time_stamp, "_", rstring(n = 20L))

    } else if (is.null(new)) {
      # Derive a sensible name from the object itself.
      x@name <- get_object_name(object = x)

    } else {
      # Use the name that was provided.
      x@name <- new
    }

    return(x)
  }
)



# get_object_name (model) ------------------------------------------------------
setMethod(
  "get_object_name",
  signature(object = "familiarModel"),
  function(object, abbreviated = FALSE) {
    # The data and run identifiers are read from the last row of the run
    # table.
    last_run <- tail(object@run_table, n = 1)
    data_id <- last_run$data_id
    run_id <- last_run$run_id

    # The abbreviated name only contains the data and run identifiers.
    if (abbreviated) {
      return(paste0("model.", data_id, ".", run_id))
    }

    # The full name is the file name of the model, without extension.
    return(get_object_file_name(
      learner = object@learner,
      fs_method = object@fs_method,
      project_id = object@project_id,
      data_id = data_id,
      run_id = run_id,
      object_type = "familiarModel",
      with_extension = FALSE))
  }
)



# model_is_trained (familiarModel) ---------------------------------------------
setMethod(
  "model_is_trained",
  signature(object = "familiarModel"),
  function(object) {
    # A model is assumed to be trained if, and only if, a model is attached to
    # the object, i.e. the model slot is not NULL.
    return(!is.null(object@model))
  }
)



# model_is_trained (character) -------------------------------------------------
setMethod(
  "model_is_trained",
  signature(object = "character"),
  function(object) {
    # Load the object referred to by the character value, then dispatch again
    # using the loaded object.
    loaded_object <- load_familiar_object(object)

    call_arguments <- list("object" = loaded_object)

    return(do.call(model_is_trained, args = call_arguments))
  }
)



# model_is_trained (NULL) ------------------------------------------------------
setMethod(
  "model_is_trained",
  signature(object = "NULL"),
  function(object) {
    # A NULL object never holds a trained model.
    FALSE
  }
)



# add_package_version (familiarModel) ------------------------------------------
setMethod(
  "add_package_version",
  signature(object = "familiarModel"),
  function(object) {
    # Delegates to .add_package_version, which sets the version of familiar.
    versioned_object <- .add_package_version(object = object)

    return(versioned_object)
  }
)



# add_data_column_info (familiarModel) -----------------------------------------
setMethod(
  "add_data_column_info",
  signature(object = "familiarModel"),
  function(
    object, 
    data = NULL, 
    sample_id_column = NULL, 
    batch_id_column = NULL, 
    series_id_column = NULL) {
    # Adds a table to the data_column_info slot that maps internal column
    # names to external column names for identifier and outcome columns.

    # Existing column information is never overwritten.
    if (!is.null(object@data_column_info)) return(object)

    # Column information attached to a dataObject is inherited as is.
    if (is(data, "dataObject")) {
      if (!is_empty(data@data_column_info)) {
        object@data_column_info <- data@data_column_info

        return(object)
      }
    }

    # Identifier columns that were not provided are read from the settings.
    # If these are not set either, they are marked as missing.
    settings <- get_settings()

    if (is.null(sample_id_column)) sample_id_column <- settings$data$sample_col
    if (is.null(sample_id_column)) sample_id_column <- NA_character_

    if (is.null(batch_id_column)) batch_id_column <- settings$data$batch_col
    if (is.null(batch_id_column)) batch_id_column <- NA_character_

    if (is.null(series_id_column)) series_id_column <- settings$data$series_col
    if (is.null(series_id_column)) series_id_column <- NA_character_

    # Repetition column ids are only internal.
    repetition_id_column <- NA_character_

    # Table with identifier columns.
    id_column_table <- data.table::data.table(
      "type" = c(
        "batch_id_column",
        "sample_id_column",
        "series_id_column",
        "repetition_id_column"),
      "internal" = get_id_columns(),
      "external" = c(
        batch_id_column,
        sample_id_column,
        series_id_column,
        repetition_id_column))

    # Determine the type labels of the outcome columns: two labels for
    # survival-like outcomes, and a single label for the other known outcome
    # types.
    if (object@outcome_type %in% c("survival", "competing_risk")) {
      outcome_type_labels <- rep("outcome_column", times = 2L)

    } else if (object@outcome_type %in% c("binomial", "multinomial", "continuous", "count")) {
      outcome_type_labels <- "outcome_column"

    } else {
      ..error_no_known_outcome_type(outcome_type = object@outcome_type)
    }

    # Table with internal and external outcome column names.
    outcome_column_table <- data.table::data.table(
      "type" = outcome_type_labels,
      "internal" = get_outcome_columns(object@outcome_type),
      "external" = object@outcome_info@outcome_column)

    # Combine into one table and add to object
    object@data_column_info <- rbind(id_column_table, outcome_column_table)

    return(object)
  }
)


# is_available -----------------------------------------------------------------
setMethod(
  "is_available",
  signature(object = "familiarModel"),
  function(object, ...) {
    # Default for the familiarModel class itself: not available.
    FALSE
  }
)



# get_default_hyperparameters --------------------------------------------------
setMethod(
  "get_default_hyperparameters",
  signature(object = "familiarModel"),
  function(object, ...) {
    # Default for the familiarModel class itself: an empty list of
    # hyperparameters.
    list()
  }
)



# ..train (familiarModel, dataObject) ------------------------------------------
setMethod(
  "..train",
  signature(
    object = "familiarModel",
    data = "dataObject"),
  function(object, data, ...) {
    # Default method: no model is created. The model slot is set to NULL and
    # the object is returned.
    object@model <- NULL

    object
  }
)



# ..train (familiarModel, NULL) ------------------------------------------------
setMethod(
  "..train",
  signature(
    object = "familiarModel",
    data = "NULL"),
  function(object, data, ...) {
    # Method for absent data: no model is created. The model slot is set to
    # NULL and the object is returned.
    object@model <- NULL

    object
  }
)



# ..train_naive (familiarModel, dataObject) ------------------------------------
setMethod(
  "..train_naive",
  signature(
    object = "familiarModel",
    data = "dataObject"),
  function(object, data, ...) {
    # Default method: no naive model is created. The model slot is set to NULL
    # and the object is returned.
    object@model <- NULL

    object
  }
)



#### ..train_naive (familiarModel, NULL) ---------------------------------------
setMethod(
  "..train_naive",
  signature(
    object = "familiarModel",
    data = "NULL"),
  function(object, data, ...) {
    # Method for absent data: no naive model is created. The model slot is set
    # to NULL and the object is returned.
    object@model <- NULL

    object
  }
)



# ..predict (familiarModel, dataObject) ----------------------------------------
setMethod(
  "..predict",
  signature(
    object = "familiarModel",
    data = "dataObject"),
  function(object, data, ...) {
    # Fall-back method: return a placeholder prediction table instead of
    # actual predictions.
    placeholder_table <- get_placeholder_prediction_table(
      object = object,
      data = data)

    return(placeholder_table)
  }
)



# ..predict (character, dataObject) --------------------------------------------
setMethod(
  "..predict",
  signature(
    object = "character",
    data = "dataObject"),
  function(object, data, ...) {
    # Load the object referred to by the character value, then dispatch again
    # using the loaded object.
    object <- load_familiar_object(object)

    # Forward any additional arguments to the method for the loaded object, in
    # the same way as the character method of ..predict_survival_probability
    # does. Previously these arguments were silently dropped.
    return(do.call(
      ..predict,
      args = c(
        list(
          "object" = object,
          "data" = data),
        list(...))))
  }
)



# ..predict_survival_probability (familiarModel, dataObject) -------------------
setMethod(
  "..predict_survival_probability",
  signature(
    object = "familiarModel",
    data = "dataObject"),
  function(object, data, time, ...) {
    # Fall-back method: return a placeholder prediction table of the
    # "survival_probability" type instead of actual predictions. The time
    # argument is not used here.
    placeholder_table <- get_placeholder_prediction_table(
      object = object,
      data = data,
      type = "survival_probability")

    return(placeholder_table)
  }
)



# ..predict_survival_probability (character, dataObject) -----------------------
setMethod(
  "..predict_survival_probability",
  signature(
    object = "character",
    data = "dataObject"),
  function(object, data, ...) {
    # Load the object referred to by the character value, then dispatch again
    # using the loaded object. Additional arguments are forwarded.
    loaded_object <- load_familiar_object(object)

    call_arguments <- c(
      list(
        "object" = loaded_object,
        "data" = data),
      list(...))

    return(do.call(..predict_survival_probability, args = call_arguments))
  }
)



# ..set_calibration_info -------------------------------------------------------
setMethod(
  "..set_calibration_info",
  signature(object = "familiarModel"),
  function(object, data) {
    # Fall-back method: calibration info that is already present is kept. An
    # absent (NULL) calibration info slot is explicitly set to NULL.
    if (is.null(object@calibration_info)) {
      object@calibration_info <- NULL
    }

    object
  }
)



# ..set_recalibration_model ----------------------------------------------------
setMethod(
  "..set_recalibration_model",
  signature(
    object = "familiarModel",
    data = "dataObject"),
  function(object, data) {
    # Fall-back method: no recalibration model is created. The
    # calibration_model slot is set to NULL and the object is returned.
    object@calibration_model <- NULL

    object
  }
)



# ..set_risk_stratification_thresholds -----------------------------------------
setMethod(
  "..set_risk_stratification_thresholds",
  signature(
    object = "familiarModel",
    data = "dataObject"),
  function(object, data) {
    # Thresholds are only determined for trained models with a survival or
    # competing risk outcome. In all other cases the km_info slot is set to
    # NULL.
    has_survival_outcome <- object@outcome_type %in% c("survival", "competing_risk")

    object@km_info <- if (has_survival_outcome && model_is_trained(object)) {
      .find_survival_grouping_thresholds(
        object = object,
        data = data)
    } else {
      NULL
    }

    return(object)
  }
)



# ..set_vimp_parameters --------------------------------------------------------
setMethod(
  "..set_vimp_parameters",
  signature(object = "familiarModel"),
  function(object, ...) {
    # Fall-back method: nothing is set, and the object is returned unchanged.
    object
  }
)



# ..vimp -----------------------------------------------------------------------
setMethod(
  "..vimp",
  signature(object = "familiarModel"),
  function(object, ...) {
    # Fall-back method: return a placeholder variable importance table, with
    # the learner of the model as variable importance method.
    placeholder_table <- get_placeholder_vimp_table(
      vimp_method = object@learner,
      run_table = object@run_table)

    return(placeholder_table)
  }
)



# trim_model (familiarModel)----------------------------------------------------
setMethod(
  "trim_model",
  signature(object = "familiarModel"),
  function(object, timeout = 60000, ...) {
    # Returns a trimmed version of the model object when trimming took place,
    # and the original object otherwise.

    # Without a trained model there is nothing to trim.
    if (!model_is_trained(object)) return(object)

    # Attempt to trim the model.
    trimmed_object <- .trim_model(object = object)

    if (trimmed_object@is_trimmed) {
      # The model was trimmed: go over the different functions, with both the
      # original and the trimmed object available.
      return(.replace_broken_functions(
        object = object,
        trimmed_object = trimmed_object,
        timeout = timeout))
    }

    # The model was not trimmed: keep the original object.
    return(object)
  }
)



# .trim_model (familiarModel)---------------------------------------------------
setMethod(
  ".trim_model",
  signature(object = "familiarModel"),
  function(object, ...) {
    # Default for models without a more specific method: the object is
    # returned as is.
    object
  }
)



# requires_naive_model ---------------------------------------------------------
setMethod(
  "requires_naive_model",
  signature(object = "familiarModel"),
  function(object, ...) {
    # Determines from the hyperparameters and the variable importance method
    # whether a naive model should be trained.

    # Without optimised hyperparameters this cannot be determined.
    if (!has_optimised_hyperparameters(object = object)) {
      return(FALSE)
    }

    # The signature size hyperparameter has to exist.
    signature_size <- object@hyperparameters$sign_size
    if (is.null(signature_size)) {
      return(FALSE)
    }

    # A naive model is required if the variable importance method is one of
    # the no-features methods.
    if (object@fs_method %in% .get_available_no_features_vimp_methods()) {
      return(TRUE)
    }

    # Otherwise a naive model is required only if the signature size is 0.
    return(all(signature_size == 0))
  }
)



# ..update_warnings-------------------------------------------------------------
setMethod(
  "..update_warnings",
  signature(object = "familiarModel"),
  function(object, messages) {
    # Append the new messages to the warnings that are already attached to
    # the object.
    current_warnings <- object@messages$warning
    object@messages$warning <- c(current_warnings, messages)

    return(object)
  }
)


# ..update_errors---------------------------------------------------------------
setMethod(
  "..update_errors",
  signature(object = "familiarModel"),
  function(object, messages) {
    # Append the new messages to the errors that are already attached to the
    # object.
    current_errors <- object@messages$error
    object@messages$error <- c(current_errors, messages)

    return(object)
  }
)


# has_calibration_info ---------------------------------------------------------
setMethod(
  "has_calibration_info",
  signature(object = "familiarModel"),
  function(object) {
    # Calibration information is considered present when the calibration_info
    # slot is not NULL.
    calibration_info <- object@calibration_info

    return(!is.null(calibration_info))
  }
)


# set_signature (familiarModel)-------------------------------------------------
setMethod(
  "set_signature",
  signature(object = "familiarModel"),
  function(
    object,
    rank_table = NULL,
    signature_features = NULL,
    minimise_footprint = FALSE,
    ...) {
    # Derive the signature from the table with ranked features when it was
    # not supplied directly. The signature features may be clustered.
    if (is.null(signature_features)) {
      signature_features <- get_signature(
        object = object,
        rank_table = rank_table)
    }

    # Feature information as currently stored with the model.
    feature_info_list <- object@feature_info

    # Important features are those that constitute the signature, either
    # individually or as part of a cluster.
    model_features <- get_model_features(
      x = signature_features,
      is_clustered = TRUE,
      feature_info_list = feature_info_list)

    # Novelty features that belong to the model features.
    novelty_features <- find_novelty_features(
      model_features = model_features,
      feature_info_list = feature_info_list)

    # Features that are directly needed by the model.
    core_features <- union(model_features, novelty_features)

    if (minimise_footprint) {
      # Keep only the features that are required for running the model.
      required_features <- core_features

    } else {
      # Keep all features that are required for processing the data.
      required_features <- get_required_features(
        x = core_features,
        is_clustered = FALSE,
        feature_info_list = feature_info_list)
    }

    # Retain only the feature info objects of required features.
    object@feature_info <- feature_info_list[
      names(feature_info_list) %in% required_features]

    # Store the feature sets in their attribute slots.
    object@required_features <- required_features
    object@model_features <- model_features
    object@novelty_features <- novelty_features

    return(object)
  }
)



# get_signature (familiarModel)-------------------------------------------------
setMethod(
  "get_signature",
  signature(object = "familiarModel"),
  function(
    object,
    rank_table = NULL,
    ...) {
    # Without model features stored in the object, determine the signature
    # from the stored feature information, the variable importance method and
    # the hyperparameters.
    if (is_empty(object@model_features)) {
      return(get_signature(
        object = object@feature_info,
        vimp_method = object@fs_method,
        parameter_list = object@hyperparameters,
        rank_table = rank_table))
    }

    # Otherwise the signature is obtained directly from the model features
    # attached to the object.
    return(features_after_clustering(
      features = object@model_features,
      feature_info_list = object@feature_info))
  }
)



# get_signature (list)----------------------------------------------------------
setMethod(
  "get_signature",
  signature(object = "list"),
  function(
    object,
    vimp_method,
    parameter_list,
    rank_table,
    ...) {
    # Select the features that form the signature.
    #
    # object: list of feature information objects. Its names are used as
    #   feature names.
    # vimp_method: name of the variable importance method, which determines
    #   how features are selected.
    # parameter_list: list of hyperparameters. The sign_size element sets the
    #   signature size, and is treated as 0 when it is empty.
    # rank_table: table with name and rank columns. It is only used by
    #   methods that select features according to rank.
    #
    # Returns the names of the selected features, or NULL when the variable
    # importance method does not select any features. An error is raised for
    # signature-only methods when no feature was assigned to the signature.

    # Suppress NOTES due to non-standard evaluation in data.table
    name <- rank <- NULL

    # Get signature size
    if (is_empty(parameter_list$sign_size)) {
      signature_size <- 0
    } else {
      signature_size <- parameter_list$sign_size
    }

    # Find features that are pre-assigned to the signature. vapply is used
    # instead of sapply so that an empty list yields an empty logical vector
    # (sapply would return an empty list, which is not a valid subscript).
    in_signature <- vapply(object, is_in_signature, FUN.VALUE = logical(1L))
    signature_features <- names(object)[in_signature]

    if (vimp_method %in% .get_available_signature_only_vimp_methods()) {
      # Only select signature.
      if (length(signature_features) == 0) stop("No signature was provided.")

      selected_features <- signature_features

    } else if (vimp_method %in% .get_available_none_vimp_methods()) {
      # Select all features.
      selected_features <- features_after_clustering(
        features = get_available_features(feature_info_list = object),
        feature_info_list = object)

      # Order randomly so that there is no accidental dependency on order.
      selected_features <- fam_sample(
        x = selected_features,
        size = length(selected_features),
        replace = FALSE)

    } else if (vimp_method %in% .get_available_random_vimp_methods()) {
      # Select all features.
      selected_features <- features_after_clustering(
        features = get_available_features(feature_info_list = object),
        feature_info_list = object)

      # Shrink signature sizes that are too large.
      if (signature_size > length(selected_features)) {
        signature_size <- length(selected_features)
      }

      # Randomly pick the signature.
      selected_features <- fam_sample(
        x = selected_features,
        size = signature_size,
        replace = FALSE)

    } else if (vimp_method %in% .get_available_no_features_vimp_methods()) {
      # No features are selected.
      selected_features <- NULL

    } else {
      # Select signature and any additional features according to rank.
      selected_features <- signature_features

      # Get the number of features that may still be added to the signature.
      n_allowed_features <- signature_size - length(signature_features)

      # Check that features may be added, and the rank table is not empty.
      if (n_allowed_features > 0 && !is_empty(rank_table)) {
        # Get available features.
        features <- features_after_clustering(
          features = get_available_features(feature_info_list = object),
          feature_info_list = object)

        # Remove signature features, if any, to prevent duplicates.
        features <- setdiff(features, signature_features)

        # Keep only the ranks of available features, and order by rank.
        rank_table <- rank_table[name %in% features, ][order(rank)]

        # Add good features (low rank) to the selection
        selected_features <- c(
          signature_features,
          head(x = rank_table, n = n_allowed_features)$name)
      }
    }

    return(selected_features)
  }
)



.get_available_risklike_prediction_types <- function() {
  # Names of the prediction types that are listed as risk-like.
  risklike_types <- c(
    "hazard_ratio",
    "cumulative_hazard",
    "survival_probability")

  return(risklike_types)
}

# Try the familiar package in your browser
#
# Any scripts or data that you put into this service are public.
#
# familiar documentation built on Sept. 30, 2024, 9:18 a.m.