R/predict.R

Defines functions predictSurvProb.deeppam predict_cumu_hazards.deeppam predict_cumu_hazards predict_hazards.deeppam predict_hazards predict.deeppam

Documented in predict.deeppam predictSurvProb.deeppam

#' Predicted values based on deeppam object.
#'
#' Removes the "deeppam" class from \code{object}, forwards the call to the
#' \code{predict()} method of the remaining class(es) and reshapes the result
#' according to \code{output}.
#' @param object keras model of object class "deeppam".
#' @param x non-optional data for predictions. Must be the same format as
#' for fitting, i.e. either a single array or a list of arrays. For list
#' input the first element is the array whose first slice along the third
#' dimension is zeroed when \code{internal = FALSE}.
#' @param output a character indicating how the output should be shaped.
#' "tensor" is the standard output, having dimensions nxtx1,
#' "matrix" is reduced to nXt,
#' "numeric" is reduced to n*t.
#' @param internal a logical indicating if the prediction is internal (used for
#' e.g. computing validation loss) or external (e.g. for evaluation).
#' This argument has effect on whether the offset is considered when predicting.
#' Thus, the default is TRUE so that during fitting this can be accessed.
#' @param ... further arguments passed to or from other methods.
#' @return Depending on \code{output}: the prediction as returned by the
#' underlying \code{predict()} method ("tensor"), the same values as a matrix
#' ("matrix"), or a numeric vector obtained by flattening that matrix row by
#' row ("numeric"). The first two dimensions carry the dimnames of the input
#' array, if it has any.
#' @export
#' @author Philipp Kopper
#' @rdname predict.deeppam
#' @import dplyr
#' @import mgcv
predict.deeppam <-
  function(object, x = NULL, output = c("tensor", "matrix", "numeric"),
           internal = TRUE, ...) {
    if (is.null(x)) {
      stop("You need to supply the data for which you want predictions.")
    }
    output <- match.arg(output, c("tensor", "matrix", "numeric"))
    # Drop the "deeppam" class so that predict() below dispatches to the method
    # of the underlying model class instead of calling this function again.
    class(object) <- class(object)[class(object) != "deeppam"]
    if (!internal) {
      # External predictions: overwrite the first slice of the third dimension
      # (the offset, according to the documentation of `internal`) with 0.
      if (!is.list(x)) {
        x[, , 1] <- 0
      } else {
        x[[1]][, , 1] <- 0
      }
    }
    predicted_tensor <- predict(object, x = x, ...)
    # Take the labels from the array itself, or from the first element of a
    # list input. Branching on is.list() (as above) instead of length(x) makes
    # this work for plain arrays and for lists of length one as well.
    reference <- if (is.list(x)) x[[1]] else x
    dimnames(predicted_tensor) <- dimnames(reference)[1:2]
    if (output == "tensor") {
      return(predicted_tensor)
    }
    tensor_dims <- dim(predicted_tensor)
    if (length(tensor_dims) == 3L && tensor_dims[3L] == 1L) {
      # Remove only the trailing singleton dimension. Plain `[` subsetting
      # would also drop a first or second dimension of extent one and return
      # a vector instead of the documented n x t matrix.
      labels <- dimnames(predicted_tensor)[1:2]
      predicted_matrix <- predicted_tensor
      dim(predicted_matrix) <- tensor_dims[1:2]
      dimnames(predicted_matrix) <- labels
    } else {
      predicted_matrix <- predicted_tensor[seq_len(tensor_dims[1]),
                                           seq_len(tensor_dims[2]), ]
    }
    if (output == "matrix") {
      predicted_matrix
    } else {
      # Transposing first yields row-major order: all time points of the first
      # observation, then all time points of the second one, and so on.
      as.numeric(t(predicted_matrix))
    }
  }

# S3 generic for hazard predictions; the method is selected by the class of
# `object` and all further arguments are handed on to it.
predict_hazards <- function(object, ...) UseMethod("predict_hazards")

# Hazard predictions for a deeppam model: calls predict.deeppam() with
# internal = FALSE.
#
# object: model of class "deeppam".
# output: shape of the result, "tensor", "matrix" or "vector". "numeric" (the
#         name predict.deeppam() uses for the flattened output) is accepted as
#         a synonym of "vector".
# ...:    passed on to predict.deeppam(); the data should be supplied by name
#         as `x = `, because a second positional argument is matched to
#         `output`.
predict_hazards.deeppam <-
  function(object, output = c("tensor", "matrix", "vector"), ...) {
    # Use the first entry only, so that the untouched default resolves to
    # "tensor" instead of reaching predict.deeppam() as a length-3 vector.
    output <- match.arg(output[1L], c("tensor", "matrix", "vector", "numeric"))
    if (output == "vector") {
      output <- "numeric"
    }
    # Pass `output` by name; positionally it could be matched to `x`.
    predict.deeppam(object, output = output, internal = FALSE, ...)
  }

# tbd
# S3 generic for cumulative hazard predictions; the method is selected by the
# class of `object` and all further arguments are handed on to it.
predict_cumu_hazards <- function(object, ...) {
  # Dispatch under the generic's own name so that predict_cumu_hazards.<class>
  # methods are called rather than the predict_hazards.<class> ones.
  UseMethod("predict_cumu_hazards")
}

# tbd
# Unfinished method for cumulative hazards of a deeppam model. It currently
# requests the matrix-shaped prediction from predict.deeppam() and then always
# signals an error; `output` is not used yet.
predict_cumu_hazards.deeppam <-
  function(object, output = c("tensor", "matrix", "vector"), ...) {
    # The prediction is computed first, so problems with the supplied data
    # surface before the "not finished" error below.
    hazard_matrix <- predict.deeppam(object, output = "matrix", ...)
    stop("predict_cumu_hazards not finished yet.")
  }


#' Predicting Survival Probabilities
#'
#' Function to extract survival probability predictions from deeppam.
#' This function works only with pec::pec.
#' The preprocessed test data are read from the "test_data" attribute of
#' \code{object}; they are reordered so that their first dimension follows the
#' order of the ids in \code{newdata}.
#' @param object a keras model (deeppam object) carrying the preprocessed test
#' data in its "test_data" attribute.
#' @param newdata A data frame containing predictor variable combinations for
#' which to compute predicted survival probabilities. It must contain the id
#' variable named in the "trafo_args" attribute of the test data.
#' @param times A vector of times in the range of the response variable, e.g.
#' times when the response is a survival object, at which to return the
#' survival probabilities. Must not exceed the largest cut point stored in
#' the "trafo_args" attribute of the test data.
#' @return A matrix with as many rows as NROW(newdata) and as many columns as
#' length(times). Each entry should be a probability and in rows the values
#' should be decreasing.
#' @author Philipp Kopper, Andreas Bender
#' @import dplyr
#' @import mgcv
#' @importFrom purrr map
predictSurvProb.deeppam <- function(
  object,
  newdata,
  times) {

  # attr() returns NULL (it does not raise an error) when the attribute is
  # missing, so the absence of test data has to be detected with is.null().
  deeppamdata <- attr(object, "test_data")
  if (is.null(deeppamdata)) {
    stop("You need to add the preprocessed (i.e. tensor format) test data (i.e. of <newdata>) via add_test_data() to the DeepPAM.")
  }
  # The test data are either one array or a list of arrays whose first element
  # carries the attributes ("trafo_args", "raw") and the id dimnames.
  list_input <- !is.array(deeppamdata)
  ped <- if (list_input) deeppamdata[[1]] else deeppamdata
  id_var     <- attr(ped, "trafo_args")[["id"]]
  id_order_newdata <- newdata[[id_var]]
  current_order <- dimnames(ped)[[1]]
  neworder <- match(id_order_newdata, current_order)
  if (anyNA(neworder)) {
    # Indexing with NA would silently create rows filled with NA.
    stop("Some ids in <newdata> are not contained in the test data added via add_test_data().")
  }
  # Zero the first slice of the third dimension and bring the observations
  # into the order of <newdata>. drop = FALSE keeps the array rank even when
  # only a single observation is selected.
  if (list_input) {
    deeppamdata[[1]][, , 1] <- 0
    deeppamdata[[1]] <- deeppamdata[[1]][neworder, , , drop = FALSE]
    # seq_along()[-1] is empty for a one-element list, unlike 2:length().
    for (i in seq_along(deeppamdata)[-1]) {
      dims <- length(dim(deeppamdata[[i]]))
      if (dims == 2) {
        deeppamdata[[i]] <- deeppamdata[[i]][neworder, , drop = FALSE]
      } else if (dims == 3) {
        deeppamdata[[i]] <- deeppamdata[[i]][neworder, , , drop = FALSE]
      } else if (dims == 4) {
        deeppamdata[[i]] <- deeppamdata[[i]][neworder, , , , drop = FALSE]
      } else if (dims == 5) {
        deeppamdata[[i]] <- deeppamdata[[i]][neworder, , , , , drop = FALSE]
      } else {
        stop("Only max. 5 dimensional input testable yet.")
      }
    }
  } else {
    deeppamdata[, , 1] <- 0
    deeppamdata <- deeppamdata[neworder, , , drop = FALSE]
  }
  raw        <- attr(ped, "raw")
  trafo_args <- attr(ped, "trafo_args")
  brks       <- trafo_args[["cut"]]
  if ( max(times) > max(brks) ) {
    stop("Cannot predict beyond the last time point used during model estimation.
        Check the 'times' argument.")
  }

  ped_times <- sort(unique(union(c(0, brks), times)))
  # extract relevant intervals only, keeps data small
  ped_times <- ped_times[ped_times <= max(times)]
  # obtain interval information
  ped_info <- get_intervals(brks, ped_times[-1])
  # add adjusted offset such that cumulative hazard and survival probability
  # can be calculated correctly
  ped_info[["intlen"]] <- c(ped_info[["times"]][1], diff(ped_info[["times"]]))
  # create data set with interval/time + covariate info
  predicted_hazards <- predict_hazards(object, x = deeppamdata, output = "numeric")
  n <- nrow(newdata)
  tend <- as.numeric(unique(raw$tend))
  # this obviously not correct fragment is necessary due to an odd sorting
  # behaviour of the pec package
  # NOTE(review): ids are numbered 1..n here and joined on a column literally
  # named "id" below, independent of `id_var` -- confirm this is intended.
  pred_frame <- data.frame(id = rep(1:n, each = length(tend)),
                           tend = rep(tend, n), pred = predicted_hazards)
  newdata <- combine_df(ped_info, newdata)
  env_times <- times
  newdata <-
    dplyr::inner_join(newdata, pred_frame, by = c("id" = "id", "tend" = "tend"))
  newdata[["intlen"]] <- recompute_intlen(newdata)
  # Survival probability = exp(-cumulative hazard), accumulated per id over
  # the ordered time points; only the requested times are kept.
  newdata <- newdata %>%
    arrange(.data$id, .data$times) %>%
    group_by(.data$id) %>%
    mutate(pred = exp(-cumsum(.data$pred * .data$intlen))) %>%
    ungroup() %>%
    filter(.data[["times"]] %in% env_times)
  id <- unique(newdata[[id_var]])
  # one vector of survival probabilities per id, stacked row-wise
  pred_list <- map(
    id,
    ~ newdata[newdata[[id_var]] == .x, "pred"] %>% pull("pred"))
  res <- do.call(rbind, pred_list)
  res
}
pkopper/deeppam documentation built on Jan. 19, 2021, 12:39 a.m.