R/surv_ceteris_paribus.R

Defines functions calculate_variable_survival_profile.default calculate_variable_survival_profile calculate_variable_split.default calculate_variable_split surv_ceteris_paribus.default surv_ceteris_paribus.surv_explainer surv_ceteris_paribus

Documented in surv_ceteris_paribus surv_ceteris_paribus.surv_explainer

#' Helper functions for `predict_profile.R`
#'
#' S3 generic; the method is selected from the class of `x`.
#'
#' @rdname surv_ceteris_paribus
#' @keywords internal
surv_ceteris_paribus <- function(x, ...) {
    UseMethod("surv_ceteris_paribus", x)
}

#' Helper functions for `predict_profile.R`
#'
#' Method for explainer objects: validates the explainer, extracts its data,
#' model, label, evaluation times and prediction function, and hands them over
#' to the default method.
#'
#' @rdname surv_ceteris_paribus
#'
#' @param x an explainer object - model preprocessed by the `explain()` function
#' @param new_observation a new observation for which predictions need to be explained
#' @param variables character, names of the variables to be included in the calculations
#' @param categorical_variables character vector, names of variables that should be treated as categories (factors are included by default)
#' @param variable_splits named list of splits for variables, in most cases created with internal functions. If NULL then it will be calculated based on validation data available in the explainer
#' @param grid_points maximum number of points for profile calculations. Note that the final number of points may be lower than grid_points. Will be passed to internal function. By default `101`.
#' @param variable_splits_type character, decides how variable grids should be calculated. Use `"quantiles"` for percentiles or `"uniform"` (default) to get uniform grid of points.
#' @param center logical, passed to the default method; when `TRUE` the mean prediction for `new_observation` at each time point is subtracted from the profiles. By default `FALSE`.
#' @param output_type character, `"survival"` (default) uses the explainer's `predict_survival_function`; any other value uses its `predict_cumulative_hazard_function`.
#' @param ... other parameters, passed on to the default method
#'
#' @return A data.frame containing the result of the calculation.
#'
#' @keywords internal
surv_ceteris_paribus.surv_explainer <- function(x,
                                                new_observation,
                                                variables = NULL,
                                                categorical_variables = NULL,
                                                variable_splits = NULL,
                                                grid_points = 101,
                                                variable_splits_type = "uniform",
                                                center = FALSE,
                                                output_type = "survival",
                                                ...) {
    test_explainer(x, has_data = TRUE, has_survival = TRUE, has_y = TRUE, function_name = "ceteris_paribus_survival")

    # choose the prediction function that matches the requested output
    prediction_fun <- if (output_type == "survival") {
        x$predict_survival_function
    } else {
        x$predict_cumulative_hazard_function
    }

    explainer_data <- x$data
    explainer_model <- x$model
    explainer_label <- x$label
    eval_times <- x$times

    # the observation's own values are always merged into the variable grids
    surv_ceteris_paribus.default(
        x = explainer_model,
        data = explainer_data,
        predict_survival_function = prediction_fun,
        new_observation = new_observation,
        variables = variables,
        categorical_variables = categorical_variables,
        variable_splits = variable_splits,
        grid_points = grid_points,
        variable_splits_type = variable_splits_type,
        variable_splits_with_obs = TRUE,
        center = center,
        label = explainer_label,
        times = eval_times,
        ...
    )
}

# Default method: computes survival ceteris paribus profiles for a raw model.
#
# Arguments:
#   x                         the model object handed to `predict_survival_function`
#   data                      data used to derive the variable grids; must not be NULL
#   predict_survival_function function called as f(model, newdata, times)
#   new_observation           observation(s) for which the profiles are computed
#   variables                 names of variables to profile; all columns of `data`
#                             when NULL (only used when `variable_splits` is NULL)
#   categorical_variables     names or numeric column indices of variables treated
#                             as categorical; factor columns are always included
#   variable_splits           optional named list of precomputed grids
#   grid_points, variable_splits_type
#                             forwarded to `calculate_variable_split()`
#   variable_splits_with_obs  if TRUE the values of `new_observation` are merged
#                             into the computed grids
#   center                    forwarded to `calculate_variable_survival_profile()`
#   label                     value of the `_label_` column added to the result
#   times                     evaluation time points
#   ...                       forwarded to `calculate_variable_survival_profile()`
#
# Returns a list of class c("surv_ceteris_paribus", "list") with elements
# `eval_times`, `variable_values`, `result` and `center`.
surv_ceteris_paribus.default <- function(x,
                                         data,
                                         predict_survival_function = NULL,
                                         new_observation,
                                         variables = NULL,
                                         categorical_variables = NULL,
                                         variable_splits = NULL,
                                         grid_points = 101,
                                         variable_splits_type = "uniform",
                                         variable_splits_with_obs = TRUE,
                                         center = FALSE,
                                         label = NULL,
                                         times = NULL,
                                         ...) {
    # validate first: every step below indexes into `data`
    if (is.null(data)) {
        stop("The ceteris_paribus() function requires explainers created with specified 'data'.")
    }

    # restrict both tables to the columns they share
    if (is.data.frame(data)) {
        common_variables <- intersect(colnames(new_observation), colnames(data))
        new_observation <- new_observation[, common_variables, drop = FALSE]
        data <- data[, common_variables, drop = FALSE]
    }

    # change categorical_features to column names
    if (is.numeric(categorical_variables)) categorical_variables <- colnames(data)[categorical_variables]
    additional_categorical_variables <- categorical_variables
    # vapply keeps the result logical even when `data` has no columns
    factor_variables <- colnames(data)[vapply(data, is.factor, logical(1))]
    categorical_variables <- unique(c(additional_categorical_variables, factor_variables))

    # calculate splits
    if (is.null(variable_splits)) {
        if (is.null(variables)) {
            variables <- colnames(data)
        }
        variable_splits <- calculate_variable_split(
            data,
            variables = variables,
            categorical_variables = categorical_variables,
            grid_points = grid_points,
            variable_splits_type = variable_splits_type,
            # NA is the "no observation" sentinel understood by the split helper
            new_observation = if (variable_splits_with_obs) new_observation else NA
        )
    }

    profiles <- calculate_variable_survival_profile(
        new_observation,
        variable_splits,
        x,
        center,
        predict_survival_function,
        times,
        ...
    )

    profiles$`_vtype_` <- ifelse(profiles$`_vname_` %in% categorical_variables, "categorical", "numerical")

    attr(profiles, "times") <- times
    attr(profiles, "observations") <- new_observation

    ret <- list(
        eval_times = times,
        variable_values = new_observation,
        result = cbind(profiles, `_label_` = label),
        center = center
    )

    class(ret) <- c("surv_ceteris_paribus", "list")

    ret
}


# S3 generic that builds the grid of candidate values ("splits") for each
# requested variable; the method is chosen from the class of `data`.
calculate_variable_split <- function(data,
                                     variables = colnames(data),
                                     categorical_variables = NULL,
                                     grid_points = 101,
                                     variable_splits_type = "quantiles",
                                     new_observation = NA) {
    UseMethod("calculate_variable_split", data)
}

# Builds, for every variable in `variables`, the sorted vector of values at
# which the profile is evaluated.
#
#   data                  table whose columns define the grids
#   variables             names of the columns to build grids for
#   categorical_variables names of columns treated as categorical: their grid is
#                         the set of distinct observed values
#   grid_points           number of points requested for numerical grids (fewer
#                         may remain after duplicates are removed)
#   variable_splits_type  "quantiles" uses empirical quantiles; any other value
#                         gives an evenly spaced grid between min and max
#   new_observation       observation(s) whose values are merged into the grids,
#                         or NA (the default) to use `data` only
#
# Returns a list of grids named by `variables`.
#' @importFrom stats na.omit quantile
#' @keywords internal
calculate_variable_split.default <- function(data, variables = colnames(data), categorical_variables = NULL, grid_points = 101, variable_splits_type = "quantiles", new_observation = NA) {
    # Only the dimensionless all-NA sentinel (or NULL) means "no observation".
    # Missing values inside a real observation are handled per variable below,
    # so one NA no longer removes the observation from every grid.
    no_observation <- is.null(new_observation) ||
        (is.null(dim(new_observation)) && all(is.na(new_observation)))

    variable_splits <- lapply(variables, function(var) {
        selected_column <- na.omit(data[, var])

        if (!(var %in% categorical_variables)) {
            if (variable_splits_type == "quantiles") {
                probs <- seq(0, 1, length.out = grid_points)
                selected_splits <- unique(quantile(selected_column, probs = probs))
            } else {
                # missing values were already dropped by na.omit() above
                selected_splits <- seq(min(selected_column), max(selected_column), length.out = grid_points)
            }
            if (!no_observation) {
                selected_splits <- sort(unique(c(selected_splits, na.omit(new_observation[, var]))))
            }
        } else {
            if (no_observation) {
                selected_splits <- sort(unique(selected_column))
            } else {
                # rbind() merges factor levels; sort() drops any NA entries
                selected_splits <- sort(unique(rbind(
                    data[, var, drop = FALSE],
                    new_observation[, var, drop = FALSE]
                )[, 1]))
            }
        }
        selected_splits
    })
    names(variable_splits) <- variables
    variable_splits
}



# S3 generic computing the profiles for the observations in `data` over the
# grids in `variable_splits`; dispatches on the class of the first argument.
calculate_variable_survival_profile <- function(data,
                                                variable_splits,
                                                model,
                                                center,
                                                predict_survival_function = NULL,
                                                times = NULL,
                                                ...) {
    UseMethod("calculate_variable_survival_profile")
}

# Computes the profiles: for every variable in `variable_splits`, each
# observation in `data` is replicated once per grid value, the variable is
# overwritten with the grid, and the model is evaluated at all `times`.
#
#   data                      observation(s) to profile
#   variable_splits           named list of grids, one element per variable
#   model                     model object handed to `predict_survival_function`
#   center                    if TRUE, the mean prediction for `data` at each time
#                             point is subtracted from the profile values
#   predict_survival_function function called as f(model, newdata, times)
#   times                     evaluation time points
#   ...                       not used by this method
#
# Returns a data.frame of class "individual_variable_profile" in long format:
# the modified observation columns plus `_times_`, `_yhat_`, `_vname_` and
# `_ids_`, with one row per (observation, grid value, time point).
calculate_variable_survival_profile.default <- function(data, variable_splits, model, center, predict_survival_function = NULL, times = NULL, ...) {
    variables <- names(variable_splits)
    prog <- progressr::progressor(along = seq_along(variables))

    # data frames always carry row names; the fallback covers objects such as
    # matrices that may not
    ids <- rownames(data)
    if (is.null(ids)) {
        ids <- seq_len(nrow(data))
    }

    # the reference prediction is only needed for centering
    if (center) {
        predictions_original <- predict_survival_function(model, data, times)
        mean_pred <- colMeans(predictions_original)
    }

    profiles <- lapply(variables, function(variable) {
        split_points <- variable_splits[[variable]]

        # one copy of every observation per grid value
        new_data <- data[rep(seq_len(nrow(data)), each = length(split_points)), , drop = FALSE]
        new_data[, variable] <- rep(split_points, nrow(data))

        # flatten row by row so that time varies fastest
        # (assumes the prediction has one row per observation and one column
        # per time point)
        yhat <- c(t(predict_survival_function(model, new_data, times)))
        if (center) {
            # mean_pred holds one value per time point and is recycled along
            # yhat, matching its time-fastest layout
            yhat <- yhat - mean_pred
        }

        # drop = FALSE keeps the column name when `data` has a single column
        new_data <- data.frame(new_data[rep(seq_len(nrow(new_data)), each = length(times)), , drop = FALSE],
            `_times_` = rep(times, times = nrow(new_data)),
            `_yhat_` = yhat,
            `_vname_` = variable,
            `_ids_` = rep(ids, each = length(times) * length(split_points)),
            check.names = FALSE
        )
        prog()
        new_data
    })

    profile <- do.call(rbind, profiles)
    class(profile) <- c("individual_variable_profile", class(profile))
    profile
}

Try the survex package in your browser

Any scripts or data that you put into this service are public.

survex documentation built on Oct. 25, 2023, 1:06 a.m.