R/survival_forest.R

Defines functions augment.ranger_survival_exploratory survival_rate_to_predicted_time survival_time_to_predicted_rate glance.ranger_survival_exploratory tidy.ranger_survival_exploratory exp_survival_forest partial_dependence.ranger_survival_exploratory calc_survival_curves_with_strata calc_permutation_importance_ranger_survival calc_mean_survival extract_survival_rate_at

# Extracts survival rates at specific time from survival curves.
# survival - Matrix of survival curves. Rows represents observations and columns represents points of time.
# unique_death_times - Numeric vector of unique death times.
# at - Time at which the survival rates should be extracted from survival curves.
extract_survival_rate_at <- function(survival, unique_death_times, at) {
  index <- sum(unique_death_times <= at)
  survival[, index]
}

# Calculate weighted mean survival time from predicted survival curve.
calc_mean_survival <- function(survival, unique_death_times) {
  # If the survival rate at the last unique death time is not 0,
  # we assume that the survivers lived one more term, just so that we can calculate the finite mean.
  # Thus the unique.death.times is appended with max(unique.death.times)+1.
  ret <- -matrixStats::rowDiffs(cbind(1,survival,0)) %*% c(unique_death_times, max(unique_death_times)+1)
  ret
}

# Permutation importance. The one from ranger seems too unstable for our use. Maybe it's based on OOB prediction?
calc_permutation_importance_ranger_survival <- function(fit, time_col, status_col, vars, data) {
  var_list <- as.list(vars)
  importances <- purrr::map(var_list, function(var) {
    mmpf::permutationImportance(data, vars=var, y=time_col, model=fit, nperm=5, # Since the result seems too unstable with nperm=1, where we use elsewhere, here we use 5.
                                predict.fun = function(object, newdata) {
                                  predicted <- predict(object,data=newdata)
                                  # Use the weighted mean predicted survival time as the predicted value. To evaluate prediction performance, we will later calcualte concordance based on it.
                                  mean_survival <- calc_mean_survival(predicted$survival, predicted$unique.death.times)
                                  tibble::tibble(x=mean_survival,status=newdata[[status_col]])
                                },
                                loss.fun = function(x,y){ # Use 1 - concordance as loss function.
                                  df <- x %>% dplyr::mutate(time=!!y[[1]])
                                  1 - survival::concordance(survival::Surv(time, status)~x,data=df)$concordance
                                })
  })
  importances <- purrr::flatten_dbl(importances)
  importances_df <- tibble::tibble(variable=vars, importance=pmax(importances, 0)) # Show 0 for negative importance, which can be caused by chance in case of permutation importance.
  importances_df <- importances_df %>% dplyr::arrange(-importance)
  importances_df
}

# Calculates survival curves with Kaplan-Meier with strata specified by vars argument.
calc_survival_curves_with_strata <- function(df, time_col, status_col, vars, orig_predictor_classes) {
  # Create map from variable name to chart type. TODO: Eliminate duplicated code.
  chart_type_map <- c()
  x_type_map <-c()
  for(col in colnames(df)) {
    if (is.numeric(df[[col]])) {
      x_type <- "numeric"
      chart_type <- "line"
    }
    # Since we turn logical into factor in preprocess_regression_data_after_sample(), detect them accordingly.
    else if (is.factor(df[[col]]) && all(levels(df[[col]]) %in% c("TRUE", "FALSE", "(Missing)"))) {
      x_type <- "logical"
      chart_type <- "scatter"
    }
    # Since we turn character into factor in preprocess_regression_data_after_sample(), look at the recorded original class.
    else if (!is.null(orig_predictor_classes) && "factor" %in% orig_predictor_classes[[col]]) {
      x_type <- "factor"
      chart_type <- "scatter"
    }
    else {
      x_type <- "character" # Since we turn charactors into factor in preprocessing (fct_lump) and cannot distinguish the original type at this point, for now, we treat both factors and characters as "characters" here.
      chart_type <- "scatter"
    }
    x_type_map <- c(x_type_map, x_type)
    chart_type_map <- c(chart_type_map, chart_type)
  }
  names(x_type_map) <- colnames(df)
  names(chart_type_map) <- colnames(df)

  vars_list <- as.list(vars)
  curve_dfs_list <- purrr::map(vars_list, function(var) {
    if (is.numeric(df[[var]])) { # categorize numeric column into 10 bins.
      grouped <- df %>% dplyr::mutate(.temp.bin.column=cut(!!rlang::sym(var), breaks=10)) %>% dplyr::group_by(.temp.bin.column)
      df <- grouped %>% dplyr::mutate(!!rlang::sym(var):=mean(!!rlang::sym(var), na.rm=TRUE))
    }
    fml <- as.formula(paste0("survival::Surv(`", time_col, "`,`", status_col, "`) ~ `", var, "`"))
    fit <- survival::survfit(fml, data = df)
    ret <- broom:::tidy.survfit(fit)
    ret
  })
  # Bind the data frames in the list. fill=TRUE is needed because it is possible that strata column is missing for some of the data frames
  # if var has only one value for all the rows.
  ret <- data.table::rbindlist(curve_dfs_list, fill=TRUE)
  ret <- ret %>% tidyr::separate(strata, into = c('variable','value'), sep='=') %>% rename(survival=estimate, period=time)
  ret <- ret %>% dplyr::mutate(chart_type = chart_type_map[variable])
  ret <- ret %>% dplyr::mutate(x_type = x_type_map[variable])
  ret
}

# Calculates partial dependence. The output is survival curves at different values of the specified variables.
partial_dependence.ranger_survival_exploratory <- function(fit, time_col, vars = colnames(data),
  n = c(min(nrow(unique(data[, vars, drop = FALSE])), 25L), nrow(data)), # Keeping same default of 25 as edarf::partial_dependence, although we usually overwrite from callers.
  interaction = FALSE, uniform = TRUE, data, ...) {
  times <- sort(unique(data[[time_col]])) # Keep vector of actual times to map time index to actual time later.
  # Add time 0 if times does not start with 0.
  time_zero_added <- FALSE
  if (times[1] != 0) {
    times <- c(0, times)
    time_zero_added <- TRUE
  }

  predict.fun <- function(object, newdata) {
    res <- predict(object, data=newdata)$survival
    # If time 0 is added, add the column of 100% survival.
    if (time_zero_added) {
      res <- cbind(rep(1,nrow(res)),res)
    }
    res
  }

  aggregate.fun <- function(x) {
    mean(x)
  }

  args = list(
    "data" = data,
    "vars" = vars,
    "n" = n,
    "model" = fit,
    "uniform" = uniform,
    "predict.fun" = predict.fun,
    "aggregate.fun" = aggregate.fun,
    ...
  )
  
  if (length(vars) > 1L & !interaction) { # More than one variables are there. Iterate calling mmpf::marginalPrediction.
    pd = data.table::rbindlist(sapply(vars, function(x) {
      args$vars = x
      if ("points" %in% names(args))
        args$points = args$points[x]
      if (!is.numeric(data[[x]])) { # If categorical, cover all categories in the data.
        n_tmp <- args$n
        n_tmp[1] <- length(unique(data[[x]]))
        args$n <- n_tmp
      }
      else if (all(data[[x]] %% 1 == 0)) { # Adjust for integer with a few unique values
        n_uniq <- max(data[[x]]) - min(data[[x]]) + 1
        if (n_uniq <= 5) {
          n_tmp <- args$n
          n_tmp[1] <- n_uniq
          args$n <- n_tmp
        }
      }
      mp = do.call(mmpf::marginalPrediction, args)
      mp
    }, simplify = FALSE), fill = TRUE)
    data.table::setcolorder(pd, c(vars, colnames(pd)[!colnames(pd) %in% vars]))
  } else {
    if (!is.numeric(data[[vars]])) { # If categorical, cover all categories in the data.
      n_tmp <- args$n
      n_tmp[1] <- length(unique(data[[vars]]))
      args$n <- n_tmp
    }
    else if (all(data[[vars]] %% 1 == 0)) { # Adjust for integer with a few unique values
      n_uniq <- max(data[[vars]]) - min(data[[vars]]) + 1
      if (n_uniq <= 5) {
        n_tmp <- args$n
        n_tmp[1] <- n_uniq
        args$n <- n_tmp
      }
    }
    pd = do.call(mmpf::marginalPrediction, args)
  }

  attr(pd, "class") = c("pd", "data.frame")
  attr(pd, "interaction") = interaction == TRUE
  attr(pd, "vars") = vars
  # Format of pd looks like this:
  #        trt age        V1        V2        V3        V4        V5        V6        V7        V8        V9       V10
  # 1 1.000000  NA 0.9984156 0.9971480 0.9867692 0.9843847 0.9631721 0.9357729 0.9165525 0.8988693 0.8679139 0.8532433
  # 2 1.111111  NA 0.9984156 0.9971480 0.9867692 0.9843847 0.9631721 0.9357729 0.9165525 0.8988693 0.8679139 0.8532433
  # 3 1.222222  NA 0.9984156 0.9971480 0.9867692 0.9843847 0.9631721 0.9357729 0.9165525 0.8988693 0.8679139 0.8532433
  ret <- pd %>% tidyr::pivot_longer(matches('^V[0-9]+$'),names_to = 'period', values_to = 'survival')
  ret <- ret %>% dplyr::mutate(period = as.numeric(stringr::str_remove(period,'^V')))
  # Format of ret looks like this:
  #     trt   age period survival
  #   <dbl> <dbl>  <dbl>    <dbl>
  # 1     1    NA      1    0.997
  # 2     1    NA      2    0.995
  # 3     1    NA      3    0.984
  # 4     1    NA      4    0.981

  chart_type_map <- c()
  for(col in colnames(ret)) {
    chart_type_map <- c(chart_type_map, is.numeric(ret[[col]]))
  }
  chart_type_map <- ifelse(chart_type_map, "line", "scatter")
  names(chart_type_map) <- colnames(ret)

  # To avoid "Error: Can't convert <double> to <character>." from pivot_longer we need to use values_transform rather than values_ptype. https://github.com/tidyverse/tidyr/issues/980
  ret <- ret %>% tidyr::pivot_longer(c(-period, -survival) ,names_to = 'variable', values_to = 'value', values_transform = list(value=as.character), values_drop_na=TRUE)
  # Format of ret looks like this:
  #   period survival variable value
  #   <chr>     <dbl> <chr>    <dbl>
  # 1 1         0.997 trt          1
  # 2 2         0.995 trt          1
  # 3 3         0.984 trt          1
  # 4 4         0.981 trt          1
  # 5 5         0.963 trt          1
  # 6 6         0.938 trt          1
  # 7 7         0.931 trt          1
  # 8 8         0.920 trt          1
  # 9 9         0.902 trt          1
  #10 10        0.896 trt          1
  ret <- ret %>%  dplyr::mutate(period = (!!times)[period]) # Map back period from index to actual time.
  ret <- ret %>%  dplyr::mutate(chart_type = chart_type_map[variable])
  ret
}

#' builds cox model quickly by way of sampling or fct_lumn, for analytics view.
#' @export
exp_survival_forest <- function(df,
                    time,
                    status,
                    ...,
                    start_time = NULL,
                    end_time = NULL,
                    time_unit = "day",
                    predictor_funs = NULL,
                    max_nrow = 50000, # With 50000 rows, taking 6 to 7 seconds on late-2016 Macbook Pro.
                    max_sample_size = NULL, # Half of max_nrow.
                    ntree = 20,
                    nodesize = 12,
                    predictor_n = 12, # so that at least months can fit in it.
                    importance_measure = "permutation", # "permutation" or "impurity".
                    max_pd_vars = NULL,
                    pd_sample_size = 500,
                    pred_survival_time = NULL,
                    pred_survival_threshold = 0.5,
                    predictor_outlier_filter_type = NULL,
                    predictor_outlier_filter_threshold = NULL,
                    seed = 1,
                    test_rate = 0.0,
                    test_split_type = "random" # "random" or "ordered"
                    ){
  # TODO: cleanup code only aplicable to randomForest. this func was started from copy of calc_feature_imp, and still adjusting for lm. 

  if (importance_measure == "permutation") { # For permutation importance, we turn off ranger's importance calculation adn use our own implementation.
    importance_measure_ranger <- "none"
  }
  else {
    importance_measure_ranger <- importance_measure
  }

  # using the new way of NSE column selection evaluation
  # ref: http://dplyr.tidyverse.org/articles/programming.html
  # ref: https://github.com/tidyverse/tidyr/blob/3b0f946d507f53afb86ea625149bbee3a00c83f6/R/spread.R
  time_col <- tidyselect::vars_select(names(df), !! rlang::enquo(time))
  start_time_col <- NULL
  end_time_col <- NULL
  if (length(time_col) == 0) { # This means time was NULL
    start_time_col <- tidyselect::vars_select(names(df), !! rlang::enquo(start_time))
    end_time_col <- tidyselect::vars_select(names(df), !! rlang::enquo(end_time))
    time_unit_days <- get_time_unit_days(time_unit, df[[start_time_col]], df[[end_time_col]])
    time_unit <- attr(time_unit_days, "label") # Get label like "day", "week".
    # We are ceiling survival time to make it integer in the specified time unit, just like we do for Survival Curve analytics view.
    # This is to make resulting survival curve to have integer data point in the specified time unit.
    # This also results in faster training of survival forest, since it reduces the number of distinct death times.
    # Though, it does give undesirable bias to longer survival time when the time unit is larger.
    df <- df %>% dplyr::mutate(.time = ceiling(as.numeric(!!rlang::sym(end_time_col) - !!rlang::sym(start_time_col), units = "days")/time_unit_days))
    time_col <- ".time"
  }
  status_col <- tidyselect::vars_select(names(df), !! rlang::enquo(status))
  # this evaluates select arguments like starts_with
  orig_selected_cols <- tidyselect::vars_select(names(df), !!! rlang::quos(...))

  if (!is.null(predictor_funs)) {
    df <- df %>% mutate_predictors(orig_selected_cols, predictor_funs)
    selected_cols <- names(unlist(predictor_funs))
  }
  else {
    selected_cols <- orig_selected_cols
  }
  # Sort predictors so that the result of permutation importance is stable against change of column order.
  selected_cols <- stringr::str_sort(selected_cols)

  grouped_cols <- grouped_by(df)

  # remove grouped col or time/status col
  selected_cols <- setdiff(selected_cols, c(grouped_cols, time_col, status_col))

  if(test_rate < 0 | 1 < test_rate){
    stop("test_rate must be between 0 and 1")
  } else if (test_rate == 1){
    stop("test_rate must be less than 1")
  }

  if (any(c(time_col, status_col, selected_cols) %in% grouped_cols)) {
    stop("grouping column is used as variable columns")
  }

  if (predictor_n < 2) {
    stop("Max # of categories for explanatory vars must be at least 2.")
  }

  # check status_col.
  if (!is.logical(df[[status_col]])) {
    stop(paste0("Status column (", status_col, ")  must be logical."))
  }

  # check time_col
  if (!is.numeric(df[[time_col]])) {
    stop(paste0("Time column (", time_col, ") must be numeric"))
  }

  if (is.null(pred_survival_time)) {
    # By default, use mean of observations with event.
    # median gave a point where survival rate was still predicted 100% in one of our test case.
    pred_survival_time <- mean((df %>% dplyr::filter(!!rlang::sym(status_col)))[[time_col]], na.rm=TRUE)
    # Pick maximum of existing values equal or less than the actual mean.
    pred_survival_time <- max((df %>% dplyr::filter(!!rlang::sym(time_col) <= !!pred_survival_time))[[time_col]], na.rm=TRUE)
    # Since 0 is most likely meaningless as a prediction time, pick 1 if the above gives 0.
    if (pred_survival_time == 0) pred_survival_time <- 1
  }

  # cols will be filtered to remove invalid columns
  cols <- selected_cols

  for (col in selected_cols) {
    if(all(is.na(df[[col]]))){
      # remove columns if they are all NA
      cols <- setdiff(cols, col)
      df[[col]] <- NULL # drop the column so that SMOTE will not see it. 
    }
  }

  # randomForest fails if columns are not clean. TODO is this needed?
  #clean_df <- janitor::clean_names(df)
  clean_df <- df # turn off clean_names for lm
  # Replace column names with names like c1_, c2_...
  # _ is so that name part and value part of categorical coefficient can be separated later,
  # even with values that starts with number like "9E".
  names(clean_df) <- paste0("c",1:length(colnames(clean_df)), "_")
  name_map <- colnames(clean_df)
  names(name_map) <- colnames(df)

  # clean_names changes column names
  # without chaning grouping column name
  # information in the data frame
  # and it causes an error,
  # so the value of grouping columns
  # should be still the names of grouping columns
  name_map[grouped_cols] <- grouped_cols
  colnames(clean_df) <- name_map

  clean_status_col <- name_map[status_col]
  clean_time_col <- name_map[time_col]
  clean_cols <- name_map[cols]
  clean_start_time_col <- NULL
  clean_end_time_col <- NULL
  if (!is.null(start_time_col)) {
    clean_start_time_col <- name_map[start_time_col]
    clean_end_time_col <- name_map[end_time_col]
  }

  each_func <- function(df) {
    tryCatch({
      if(!is.null(seed)){
        set.seed(seed)
      }

      df <- df %>%
        dplyr::filter(!is.na(df[[clean_status_col]])) # this form does not handle group_by. so moved into each_func from outside.

      df <- preprocess_regression_data_before_sample(df, clean_time_col, clean_cols)
      clean_cols <- attr(df, 'predictors') # predictors are updated (removed) in preprocess_pre_sample. Catch up with it.

      # sample the data for performance if data size is too large.
      sampled_nrow <- NULL
      if (!is.null(max_nrow) && nrow(df) > max_nrow) {
        # Record that sampling happened.
        sampled_nrow <- max_nrow
        df <- df %>% sample_rows(max_nrow)
      }

      # Remove outliers if specified so.
      # This has to be done before preprocess_regression_data_after_sample, since it can remove rows and reduce number of unique values,
      # just like sampling.
      df <- remove_outliers_for_regression_data(df, clean_time_col, clean_cols,
                                                NULL, #target_outlier_filter_type
                                                NULL, #target_outlier_filter_threshold
                                                predictor_outlier_filter_type,
                                                predictor_outlier_filter_threshold)

      # Temporarily remove unused columns for uniformity. TODO: Revive them when we do that across the product.
      clean_cols_without_names <- clean_cols
      names(clean_cols_without_names) <- NULL # remove names to eliminate renaming effect of select.
      if (is.null(clean_start_time_col)) {
        df <- df %>% dplyr::select(!!!rlang::syms(clean_cols_without_names), !!rlang::sym(clean_time_col), rlang::sym(clean_status_col))
      }
      else {
        df <- df %>% dplyr::select(!!!rlang::syms(clean_cols_without_names), !!rlang::sym(clean_start_time_col), !!rlang::sym(clean_end_time_col), !!rlang::sym(clean_time_col), rlang::sym(clean_status_col))
      }

      # Capture the classes of the columns at this point before preprocess_regression_data_after_sample,
      # so that we know the original classes of columns before characters are turned into factors,
      # so that we can sort the partial dependence data for display accordingly.
      # preprocess_regression_data_after_sample can remove columns, but it should not cause problem that we have more columns in
      # orig_predictor_classes than the partial dependence data.
      # Also, preprocess_regression_data_after_sample has code to add columns extracted from Date/POSIXct, but with recent releases,
      # that should not happen, since the extraction is already done by mutate_predictors.
      orig_predictor_classes <- capture_df_column_classes(df, clean_cols)
      df <- preprocess_regression_data_after_sample(df, clean_time_col, clean_cols, predictor_n = predictor_n, name_map = name_map)
      c_cols <- attr(df, 'predictors') # predictors are updated (added and/or removed) in preprocess_post_sample. Catch up with it.
      name_map <- attr(df, 'name_map')

      source_data <- df

      # split training and test data
      test_index <- sample_df_index(source_data, rate = test_rate, ordered = (test_split_type == "ordered"))
      df <- safe_slice(source_data, test_index, remove = TRUE)
      if (test_rate > 0) {
        df_test <- safe_slice(source_data, test_index, remove = FALSE)
        unknown_category_rows_index_vector <- get_unknown_category_rows_index_vector(df_test, df)
        df_test <- df_test[!unknown_category_rows_index_vector, , drop = FALSE] # 2nd arg must be empty.
        unknown_category_rows_index <- get_row_numbers_from_index_vector(unknown_category_rows_index_vector)
      }

      # build formula for survival forest.
      rhs <- paste0("`", c_cols, "`", collapse = " + ")
      fml <- as.formula(paste0("survival::Surv(`", clean_time_col, "`, `", clean_status_col, "`) ~ ", rhs))
      # all or max_sample_size data will be used for randomForest
      # to grow a tree
      if (is.null(max_sample_size)) { # default to half of max_nrow
        max_sample_size = max_nrow/2
      }
      sample.fraction <- min(c(max_sample_size / max_nrow, 1))
      model <- ranger::ranger(fml, data = df, importance = importance_measure_ranger,
        num.trees = ntree,
        min.node.size = nodesize,
        keep.inbag=TRUE,
        sample.fraction = sample.fraction)
      # these attributes are used in tidy of randomForest TODO: is this good for lm too?
      model$terms_mapping <- names(name_map)
      names(model$terms_mapping) <- name_map
      model$sampled_nrow <- sampled_nrow

      if (length(c_cols) > 1) { # Show importance only when there are multiple variables.
        # get importance to decide variables for partial dependence
        if (importance_measure != "permutation") {
          imp <- ranger::importance(model)
          imp_df <- tibble::tibble( # Use tibble since data.frame() would make variable factors, which breaks things in following steps.
            variable = names(imp),
            importance = imp
            ) %>% dplyr::arrange(-importance)
        }
        else { # For permutation, we use our own implementation. The one from ranger seems too unstable for our use. Maybe it's based on OOB prediction?
          imp_df <- calc_permutation_importance_ranger_survival(model, clean_time_col, clean_status_col, c_cols, df)
        }
        model$imp_df <- imp_df
        imp_vars <- imp_df$variable
      }
      else {
        error <- simpleError("Variable importance requires two or more variables.")
        model$imp_df <- error
        imp_vars <- c_cols
      }

      if (is.null(max_pd_vars)) {
        max_pd_vars <- 20 # Number of most important variables to calculate partial dependences on. This used to be 12 but we decided it was a little too small.
      }
      imp_vars <- imp_vars[1:min(length(imp_vars), max_pd_vars)] # take max_pd_vars most important variables
      model$imp_vars <- imp_vars
      model$orig_predictor_classes <- orig_predictor_classes
      model$partial_dependence <- partial_dependence.ranger_survival_exploratory(model, clean_time_col, vars = imp_vars, n = c(9, min(nrow(df), pd_sample_size)), data = df) # grid of 9 is convenient for both PDP and survival curves.
      model$pred_survival_time <- pred_survival_time
      model$pred_survival_threshold <- pred_survival_threshold
      model$survival_curves <- calc_survival_curves_with_strata(df, clean_time_col, clean_status_col, imp_vars, orig_predictor_classes)

      # Calculate concordance.
      # Concordance by model$survival is too bad, most likely because it is out-of-bag prediction. We explictly predict with training data to calculate training concordance.
      prediction_training <- predict(model, data=df)
      model$prediction_training <- prediction_training
      concordance_df <- tibble::tibble(x=calc_mean_survival(prediction_training$survival, prediction_training$unique.death.times), time=df[[clean_time_col]], status=df[[clean_status_col]])

      # The concordance is (d+1)/2, where d is Somers' d. https://cran.r-project.org/web/packages/survival/vignettes/concordance.pdf
      model$concordance <- survival::concordance(survival::Surv(time, status)~x,data=concordance_df)

      model$auc <- survival_auroc(extract_survival_rate_at(prediction_training$survival, prediction_training$unique.death.times, pred_survival_time),
                               df[[clean_time_col]], df[[clean_status_col]], pred_survival_time,
                               revert=TRUE)

      # Add cleaned column names. Used for prediction of survival rate on the specified day, and date the survival probability drops to the specified rate. 
      model$clean_start_time_col <- clean_start_time_col
      model$clean_end_time_col <- clean_end_time_col
      model$clean_time_col <- clean_time_col
      model$clean_status_col <- clean_status_col
      model$time_unit <- time_unit

      if (test_rate > 0) {
        df_test_clean <- cleanup_df_for_test(df_test, df, c_cols)
        na_row_numbers_test <- attr(df_test_clean, "na_row_numbers")
        unknown_category_rows_index <- attr(df_test_clean, "unknown_category_rows_index")

        prediction_test <- predict(model, data=df_test_clean)
        # TODO: Following current convention for the name na.action to keep na row index, but we might want to rethink.
        # We do not keep this for training since na.roughfix should fill values and not delete rows. TODO: Is this comment valid here for survival forest?
        attr(prediction_test, "na.action") <- na_row_numbers_test
        attr(prediction_test, "unknown_category_rows_index") <- unknown_category_rows_index
        model$prediction_test <- prediction_test
        model$df_test <- df_test_clean

        # Calculate concordance.
        concordance_df_test <- tibble::tibble(x=calc_mean_survival(prediction_test$survival, prediction_test$unique.death.times), time=df_test_clean[[clean_time_col]], status=df_test_clean[[clean_status_col]])
        # The concordance is (d+1)/2, where d is Somers' d. https://cran.r-project.org/web/packages/survival/vignettes/concordance.pdf
        model$concordance_test <- survival::concordance(survival::Surv(time, status)~x,data=concordance_df_test)

        model$auc_test <- survival_auroc(extract_survival_rate_at(prediction_test$survival, prediction_test$unique.death.times, pred_survival_time),
                                      df_test_clean[[clean_time_col]], df_test_clean[[clean_status_col]], pred_survival_time,
                                      revert=TRUE)

        model$test_nevent <- sum(df_test_clean[[clean_status_col]], na.rm=TRUE)
      }
      model$df <- df
      model$nevent <- sum(df[[clean_status_col]], na.rm=TRUE)

      if (!is.null(predictor_funs)) {
        model$orig_predictor_cols <- orig_selected_cols
        attr(predictor_funs, "LC_TIME") <- Sys.getlocale("LC_TIME")
        attr(predictor_funs, "sysname") <- Sys.info()[["sysname"]] # Save platform name (e.g. Windows) since locale name might need conversion for the platform this model will be run on.
        attr(predictor_funs, "lubridate.week.start") <- getOption("lubridate.week.start")
        model$predictor_funs <- predictor_funs
      }

      # add special lm_coxph class for adding extra info at glance().
      class(model) <- c("ranger_survival_exploratory", class(model))
      model
    }, error = function(e){
      if(length(grouped_cols) > 0) {
        # In repeat-by case, we report group-specific error in the Summary table,
        # so that analysis on other groups can go on.
        class(e) <- c("ranger_survival_exploratory", class(e))
        e
      } else {
        stop(e)
      }
    })
  }

  ret <- do_on_each_group(clean_df, each_func, name = "model", with_unnest = FALSE)
  # Pass down survival time used for prediction. This is for the post-processing for time-dependent ROC.
  attr(ret, "pred_survival_time") <- pred_survival_time
  # Pass down time unit for prediction step UI.
  attr(ret, "time_unit") <- time_unit
  ret
}

#' special version of tidy.coxph function to use with build_coxph.fast.
#' @export
tidy.ranger_survival_exploratory <- function(x, type = 'importance', ...) { #TODO: add test
  if ("error" %in% class(x)) {
    ret <- data.frame()
    return(ret)
  }
  switch(type,
    importance = {
      if (is.null(x$imp_df) || "error" %in% class(x$imp_df)) {
        # Permutation importance is not supported for the family and link function, or skipped because there is only one variable.
        # Return empty data.frame to avoid error.
        ret <- data.frame()
        return(ret)
      }
      ret <- x$imp_df
      ret <- ret %>% dplyr::mutate(variable = x$terms_mapping[variable]) # map variable names to original.
      ret
    },
    partial_dependence_survival_curve = {
      ret <- x$partial_dependence
      ret <- ret %>% dplyr::group_by(variable) %>% tidyr::nest() %>%
        dplyr::mutate(data = purrr::map(data,function(df){ # Show only 5 lines out of 9 lines for survival curve.
          if (df$chart_type[[1]] == 'line') {
            ret <- df %>% dplyr::mutate(value_index=as.integer(forcats::fct_inorder(value)))
            # %% 2 is to show only 5 lines out of 9 lines for survival curves variated by a numeric variable.
            if (max(ret$value_index) >= 9) { # It is possible that value index is less than 9 for integer with 5 or less unique values.
              ret <- ret %>% dplyr::filter(value_index %% 2 == 1) %>% dplyr::mutate(value_index=ceiling(value_index/2))
            }
            ret
          }
          else {
            df %>% dplyr::mutate(value_index=as.integer(forcats::fct_inorder(value))) %>% dplyr::mutate(value_index=value_index+5)
          }
        })) %>% tidyr::unnest() %>% dplyr::ungroup() %>% dplyr::mutate(value_index=factor(value_index)) # Make value_index a factor to control color.

      # Reduce number of unique x-axis values for better chart drawing performance, and not to overflow it.
      grid <- 40
      divider <- max(ret$period) %/% grid
      if (divider >= 2) {
        ret <- ret %>% dplyr::mutate(period = period %/% divider * divider) %>%
          dplyr::group_by(variable, value, chart_type, value_index, period) %>%
          dplyr::summarize(survival=max(survival)) %>%
          dplyr::ungroup()
      }

      ret <- ret %>% dplyr::mutate(variable = forcats::fct_relevel(variable, !!x$imp_vars)) # set factor level order so that charts appear in order of importance.
      # set order to ret and turn it back to character, so that the order is kept when groups are bound.
      # if it were kept as factor, when groups are bound, only the factor order from the first group would be respected.
      ret <- ret %>% dplyr::arrange(variable) %>% dplyr::mutate(variable = as.character(variable))
      ret <- ret %>% dplyr::mutate(chart_type = 'line')
      ret <- ret %>% dplyr::mutate(variable = x$terms_mapping[variable]) # map variable names to original.
      ret
    },
    partial_dependence = {
      ret <- x$partial_dependence
      pred_survival_time <- x$pred_survival_time
      ret <- ret %>%
        dplyr::filter(period <= !!pred_survival_time) %>% # Extract the latest period that does not exceed pred_survival_time
        dplyr::group_by(variable, value) %>% dplyr::filter(period == max(period)) %>% dplyr::ungroup() %>%
        dplyr::mutate(type='Prediction')
      actual <- x$survival_curves %>%
        dplyr::filter(period <= !!pred_survival_time) %>% # Extract the latest period that does not exceed pred_survival_time
        dplyr::group_by(variable, value) %>% dplyr::filter(period == max(period)) %>% dplyr::ungroup() %>%
        dplyr::mutate(type='Actual')
      ret <- actual %>% dplyr::bind_rows(ret) # actual rows need to come first for the order of chart drawing.
      ret <- ret %>% dplyr::mutate(variable = forcats::fct_relevel(variable, !!x$imp_vars)) # set factor level order so that charts appear in order of importance.
      # set order to ret and turn it back to character, so that the order is kept when groups are bound.
      # if it were kept as factor, when groups are bound, only the factor order from the first group would be respected.
      ret <- ret %>% dplyr::arrange(variable) %>% dplyr::mutate(variable = as.character(variable))
      ret <- ret %>% survival_pdp_sort_categorical()
      ret <- ret %>% dplyr::mutate(variable = x$terms_mapping[variable]) # map variable names to original.
      ret
    })
}

glance.ranger_survival_exploratory <- function(x, data_type = "training", ...) {
  if (data_type == "training") {
    if (!is.null(x$concordance)) {
      tibble::tibble(Concordance=x$concordance$concordance, `Std Error Concordance`=sqrt(x$concordance$var), `Time-dependent AUC`=x$auc, `Rows`=nrow(x$df), `Rows (TRUE)`=x$nevent)
    }
    else {
      data.frame()
    }
  }
  else { # data_type == "test"
    if (!is.null(x$concordance_test)) {
      tibble::tibble(Concordance=x$concordance_test$concordance, `Std Error Concordance`=sqrt(x$concordance_test$var), `Time-dependent AUC`=x$auc_test, `Rows`=nrow(x$df_test), `Rows (TRUE)`=x$test_nevent)
    }
    else {
      data.frame()
    }
  }
}

# Extracts and interpolates survival rate from survival_mat, for the time specified for each observation by time_vec. 
# survival_mat - Predicted survival curve data obtained from the prediction. 
# time_vec - Time to extract the prediction from the survival_mat for each observation.
# time_index_fun - Function to convert time into the column index + 1 of the survival_mat matrix.
#                  +1 is because we insert a column for time 0 to the matrix to cover the time before the first event. 
survival_time_to_predicted_rate <- function(survival_mat, time_vec, time_index_fun) {
  index_vec <- time_index_fun(time_vec)
  index_vec_floor <- floor(index_vec)
  index_vec_ceiling <- ceiling(index_vec)
  # Will use this index_vec_residual for interpolation between the values for index_vec_floor and index_vec_ceiling.
  index_vec_residual <- index_vec - index_vec_floor
  # 2-column matrix to pass to the survival_mat_adjusted to extract its elements as a vector.
  index_matrix_floor <- cbind(1:length(time_vec), index_vec_floor)
  index_matrix_ceiling <- cbind(1:length(time_vec), index_vec_ceiling)
  # Insert a column for time 0 to cover the time between 0 and the first event.
  survival_mat_adjusted <- cbind(rep(1,nrow(survival_mat)), survival_mat)
  # Extract the values from the matrix as vectors.
  survival_rate_floor <- survival_mat_adjusted[index_matrix_floor]
  survival_rate_ceiling <- survival_mat_adjusted[index_matrix_ceiling]
  # interpolate between survival_rate_floor and survival_rate_ceiling.
  survival_rate <- survival_rate_floor + (survival_rate_ceiling - survival_rate_floor)*index_vec_residual
  survival_rate
}

survival_rate_to_predicted_time <- function(survival_mat, pred_survival_rate, index_time_fun) {
  time_indice <- rep(NA_real_, nrow(survival_mat))
  for (i in 1:nrow(survival_mat)) {
    # Since survival curve monotonously decreases, we can use findInterval that does O(logN) binary search.
    idx <- findInterval(-pred_survival_rate, -survival_mat[i,])
    # Interpolate if possible.
    # When pred_survival_rate is larger than the first survival rate that appear in survivL_mat[i,].
    if (idx == 0) {
      denom <- survival_mat[i,idx+1] - 1.0
      if (!is.na(denom) && denom != 0) {
        idx <- idx + (pred_survival_rate - 1.0)/denom
      }
    }
    # General case
    else if (!is.na(idx) && idx < ncol(survival_mat)) {
      denom <- survival_mat[i,idx+1] - survival_mat[i,idx]
      if (!is.na(denom) && denom != 0) {
        idx <- idx + (pred_survival_rate - survival_mat[i,idx])/denom
      }
    }
    time_indice[i] <- idx
  }
  res <- index_time_fun(time_indice)
  res
}

#' @export
augment.ranger_survival_exploratory <- function(x, newdata = NULL, data_type = "training",
                                                base_time = NULL, base_time_type = "max", pred_time = NULL, # For point-in-time-based survival rate prediction. base_time_type can be "value", "max", or "today".
                                                pred_survival_rate = NULL, # For survival-rate-based event time prediction.
                                                pred_survival_time = NULL, pred_survival_threshold = NULL, ...) { # For survival-time-based survival rate prediction.
  # For predict() to find the prediction method, ranger needs to be loaded beforehand.
  # This becomes necessary when the model was restored from rds, and model building has not been done in the R session yet.
  loadNamespace("ranger")
  if ("error" %in% class(x)) {
    ret <- data.frame(Note = x$message)
    return(ret)
  }
  if(!is.null(newdata)) {
    # Replay the mutations on predictors.
    if(!is.null(x$predictor_funs)) {
      newdata <- newdata %>% mutate_predictors(x$orig_predictor_cols, x$predictor_funs)
    }

    predictor_variables <- attr(x$df,"predictors")
    # If start time column is in newdata, use it.
    if (!is.null(x$clean_start_time_col) && x$terms_mapping[x$clean_start_time_col] %in% colnames(newdata)) {
      predictor_variables <- c(predictor_variables, x$clean_start_time_col)
    }
    predictor_variables_orig <- x$terms_mapping[predictor_variables]

    # Rename columns via predictor_variables_orig, which is a named vector.
    #TODO: What if names of the other columns conflicts with our temporary name, c1_, c2_...?
    cleaned_data <- newdata %>% dplyr::rename(predictor_variables_orig)

    # Align factor levels including Others and (Missing) to the model. TODO: factor level order can be different from the model training data. Is this ok?
    cleaned_data <- align_predictor_factor_levels(cleaned_data, x$xlevels, predictor_variables)

    na_row_numbers <- ranger.find_na(predictor_variables, cleaned_data)
    if (length(na_row_numbers) > 0) {
      cleaned_data <- cleaned_data[-na_row_numbers,]
    }
    data <- cleaned_data
    pred <- predict(x, data=data)
  }
  else if (data_type == "training") {
    data <- x$df
    pred <- x$prediction_training
  }
  else { # data_type == "test"
    if (!is.null(x$df_test)) {
      data <- x$df_test
      pred <- x$prediction_test
    }
    else {
      return(data.frame())
    }
  }

  if (is.null(pred_survival_time)) {
    pred_survival_time <- x$pred_survival_time
  }
  if (is.null(pred_survival_threshold)) {
    pred_survival_threshold <- x$pred_survival_threshold
  }

  # Predict survival probability on a specific date.
  # or predict the day that the survival rate drops to the specified value (pred_survival_rate).
  # Used for prediction step based on the model from the Analytics View.
  if (!is.null(pred_time) || !is.null(pred_survival_rate)) {
    # Common logic between point-of-time-based survival rate prediction and rate-based event time prediction.
    time_unit_days <- get_time_unit_days(x$time_unit)
    # Function to convert time to index. Extended to cover 0.
    time_index_fun <- approxfun(c(0, x$unique.death.times), 1:(length(x$unique.death.times)+1))
    # Predict survival probability on the specified date (pred_time).
    if (!is.null(pred_time)) {
      if (base_time_type == "max") {
        base_time <- as.Date(max(cleaned_data[[x$clean_start_time_col]])) # as.Date is to take care of POSIXct column.
      }
      else if (base_time_type == "today") {
        base_time <- lubridate::today()
      } # if base_time_type is "value", use the argument value as is.
      # For casting the time for prediction to an integer days, use ceil to compensate that we ceil in the preprocessing.
      pred_time <- base_time + lubridate::days(ceiling(pred_time * time_unit_days));
      data <- data %>% dplyr::mutate(base_survival_time = as.numeric(!!base_time - as.Date(!!rlang::sym(x$clean_start_time_col)), units = "days")/time_unit_days)
      # as.Date is to handle the case where the start time column is in POSIXct.
      data <- data %>% dplyr::mutate(prediction_survival_time = as.numeric(pred_time - as.Date(!!rlang::sym(x$clean_start_time_col)), units = "days")/time_unit_days)
      base_survival_rate <- survival_time_to_predicted_rate(pred$survival, data$base_survival_time, time_index_fun)
      target_survival_rate <- survival_time_to_predicted_rate(pred$survival, data$prediction_survival_time, time_index_fun)
      data$predicted_survival_rate <- target_survival_rate / base_survival_rate
      # NA means that the specified time is not covered by the predicted survival curve. 
      data <- data %>% dplyr::mutate(note = if_else(is.na(predicted_survival_rate), "Out of range of the predicted survival curve.", NA_character_))
      data <- data %>% dplyr::mutate(base_time = !!base_time)
      data <- data %>% dplyr::mutate(prediction_time = !!pred_time)
    }
    # Predict the day that the survival rate drops to the specified value. (pred_survival_rate should be there.)
    else {
      data <- data %>% dplyr::mutate(survival_rate_for_prediction = !!pred_survival_rate)
      # Predicted survival period
      index_time_fun <- approxfun(0:length(x$unique.death.times), c(0, x$unique.death.times))
      data$predicted_survival_time <- survival_rate_to_predicted_time(pred$survival, pred_survival_rate, index_time_fun)
      # NA means that the specified survival rate is not covered by the predicted survival curve. 
      data <- data %>% dplyr::mutate(note = if_else(is.na(predicted_survival_time), "Didn't meet the threshold.", NA_character_))
      # For casting the survival time to an integer days, use floor to compensate that we ceil in the preprocessing.
      data <- data %>% dplyr::mutate(predicted_event_time = as.Date(!!rlang::sym(x$clean_start_time_col)) + lubridate::days(floor(predicted_survival_time*time_unit_days)))
    }
  }
  # Predict survival probability on the specified duration (pred_survival_time). Still used in the Analytics View itself.
  else {
    unique_death_times <- x$forest$unique.death.times
    survival_time <- max(unique_death_times[unique_death_times <= pred_survival_time])
    survival_time_index <- match(survival_time, unique_death_times)
    predicted_survival_rate <- pred$survival[,survival_time_index]
    predicted_survival <- predicted_survival_rate > pred_survival_threshold
    data$prediction_survival_time <- pred_survival_time
    data$predicted_survival_rate <- predicted_survival_rate
    data$predicted_survival <- predicted_survival
  }
  ret <- data
  # Move those columns as the last columns.
  ret <- ret %>% dplyr::relocate(any_of(c("base_time", "base_survival_time", "prediction_time", "prediction_survival_time",
                                          "predicted_survival_rate", "predicted_survival",
                                          "survival_rate_for_prediction", "predicted_survival_time", "predicted_event_time", "note")), .after=last_col())
  colnames(ret)[colnames(ret) == "predicted_survival"] <- "Predicted Survival"
  colnames(ret)[colnames(ret) == "base_survival_time"] <- "Base Survival Time"
  colnames(ret)[colnames(ret) == "prediction_survival_time"] <- "Prediction Survival Time"
  colnames(ret)[colnames(ret) == "base_time"] <- "Base Time"
  colnames(ret)[colnames(ret) == "prediction_time"] <- "Prediction Time"
  colnames(ret)[colnames(ret) == "predicted_survival_rate"] <- "Predicted Survival Rate"
  colnames(ret)[colnames(ret) == "survival_rate_for_prediction"] <- "Survival Rate for Prediction"
  colnames(ret)[colnames(ret) == "predicted_survival_time"] <- "Predicted Survival Time"
  colnames(ret)[colnames(ret) == "predicted_event_time"] <- "Predicted Event Time"
  colnames(ret)[colnames(ret) == "note"] <- "Note"

  # Convert column names back to the original.
  for (i in 1:length(x$terms_mapping)) {
    converted <- names(x$terms_mapping)[i]
    original <- x$terms_mapping[i]
    colnames(ret)[colnames(ret) == converted] <- original
  }

  colnames(ret)[colnames(ret) == ".time"] <- "Survival Time" # .time is a column name generated by our Command Generator.
  ret
}
exploratory-io/exploratory_func documentation built on April 23, 2024, 9:15 p.m.