R/method.R

#' Printing an MTLR object.
#'
#' Print an object created by \code{\link[MTLR]{mtlr}}: the original call, the time points selected by the model,
#' and the feature weights.
#' @param x an object of class mtlr (result from calling \code{\link[MTLR]{mtlr}}).
#' @param digits The number of digits to print mtlr weights (also used for the time points).
#' @param ... for future methods.
#' @return \code{x}, invisibly. The function is called for its side effect of printing: Call, the original call
#' to the mtlr function; Time points, the time points selected by the mtlr model; Weights, the weights of
#' each feature across time -- rows represent each time point and each column corresponds to a feature.
#' @seealso \code{\link[MTLR]{mtlr}}
#' @export
print.mtlr <- function(x, digits = max(getOption("digits") - 4, 3), ...) {
  cat("\nCall: ", deparse(x$Call), "\n")
  cat("\nTime points:\n")
  print(x$time_points, digits = digits)
  cat("\n\nWeights:\n")
  print(x$weight_matrix, digits = digits)
  # Print methods conventionally return their input invisibly, so that
  # `y <- print(fit)` gives back the fitted object rather than its weights.
  invisible(x)
}


#' Predictions for MTLR
#'
#' Compute survival curves and other fitted values for a model generated by \code{\link[MTLR]{mtlr}}.
#'
#' @param object an object of class mtlr, generated by \code{\link[MTLR]{mtlr}}.
#' @param newdata an optional new dataframe for which to perform predictions using MTLR. If left empty, predictions will
#' be performed using the dataset used to generate the original mtlr object -- note that any error calculation on these
#' predictions will be optimistic since this will only be the resubstitution error and not be representative of error on a new test set.
#' @param type the type of prediction desired. Options are the survival curve for the time points selected by mtlr ("survivalcurve"), the
#' survival curve for given times ("prob_times"), the probability of survival at each observation's event time ("prob_event"),
#' the mean survival time ("mean_time"), and the median survival time ("median_time").
#'
#' For "survivalcurve" and "prob_times", the first column of the returned data frame holds the time points and every other column
#' holds one observation's survival probabilities at those time points. The \emph{i}th row of newdata corresponds
#' to column \emph{i + 1} of the result.
#'
#' If "prob_event" is chosen the response (event time) is required. For both "prob_event" and "prob_times", if the event time is larger
#' than all of the time points used to build the mtlr model then the last (lowest) probability is used. For example, if the event time is
#' 100 but the largest time point estimated by the mtlr model was 80 then the survival probability at 100 is equal to the survival
#' probability at 80, \emph{i.e. S(100) = S(80)}.
#'
#' For "mean_time", if survival curves do not extend to zero survival
#' probability a linear extension is added (a linear line from (time = 0, probability = 1) to (time = ?, probability = 0)). This is the
#' same for "median_time" except the line need only extend to survival probability = 0.5.
#' A mean/median survival time of Inf is returned for survival curves with all survival probabilities of 1.
#' @param add_zero if TRUE, a time point of "0" and a survival probability of "1" will be added to all survival curves
#' (for "prob_times", 0 is also prepended to \code{times}). Additionally,
#' if add_zero is TRUE, type = "mean_time" will represent the average survival time overall but if FALSE, then "mean_time" will be reduced
#' by roughly the value of the first time point. However, "median_time" and "prob_event" will be unchanged.
#' @param times For prediction method "prob_times" you may specify the times at which to predict the survival probability for each row in newdata.
#' This value defaults to all unique event times (both censored and uncensored) in the data on which the model was trained.
#' @param ... for future methods.
#' @return The desired prediction type: a data frame for "survivalcurve" and "prob_times", otherwise a numeric
#' vector with one entry per observation.
#' @note The predictions generated by type = "survivalcurve" can be plotted using \code{\link[MTLR]{plotcurves}} -- packages
#' ggplot2 and reshape2 must be installed to use this function.
#' @seealso \code{\link[MTLR]{mtlr}} \code{\link[MTLR]{plotcurves}}
#' @examples
#' library(survival)
#' mod <- mtlr(Surv(time,status)~., data = lung)
#'
#' #Here our predictions are on the data from which we trained so our results will be optimistic
#' # since they are produced from resubstitution as opposed to some new test set.
#' predict(mod, type = "survivalcurve")
#' predict(mod, type = "prob_event")
#' predict(mod, type = "median_time")
#' predict(mod, type = "mean_time")
#'
#' #Notice the difference of about 59:
#' predict(mod, type = "mean_time", add_zero = FALSE)
#' @export
predict.mtlr <- function(object, newdata, type = c("survivalcurve","prob_times","prob_event","mean_time","median_time"),
                         add_zero = TRUE, times = NULL, ...) {
  type <- match.arg(type)
  # "prob_times" without explicit times: use every unique training event time.
  if (type == "prob_times" && length(times) < 1) {
    times <- sort(unique(object$response[, 1]))
  }
  if (missing(newdata)) {
    newframe <- object$x
  } else {
    Terms <- stats::delete.response(object$Terms)
    newframe <- stats::model.matrix(Terms, data = newdata, xlev = object$xlevels)
    # Remove the intercept column. Indexing with `[, -1]` (instead of
    # `[1:nrow(newframe), -1]`) also works when newdata has zero rows.
    newframe <- newframe[, -1, drop = FALSE]
    if (!is.null(object$scale)) {
      newframe <- scale(newframe, center = object$scale$center, scale = object$scale$sd)
      # Features with zero variance in the training set become NaN after
      # scaling; set those columns to 0 instead.
      zero_sd <- which(object$scale$sd == 0)
      if (length(zero_sd) > 0) {
        newframe[, zero_sd] <- 0
      }
    }
  }

  surv_probs <- mtlr_predict(c(object$weight_matrix), newframe)
  # Machine precision can give survival probabilities of 1 + 1e-16; clamp to 1.
  surv_probs[surv_probs > 1] <- 1
  if (add_zero) {
    times <- c(0, times)
    time_points <- c(0, object$time_points)
    surv_probs <- rbind(1, surv_probs)
  } else {
    time_points <- object$time_points
  }
  surv_curves <- cbind.data.frame(time = time_points, surv_probs)
  # One column per observation. drop = FALSE keeps this a data frame when
  # there is a single observation; otherwise `[, -1]` collapses to a numeric
  # vector and the helpers below would be applied to each probability
  # separately instead of to the whole curve.
  curves <- surv_curves[, -1, drop = FALSE]

  switch(type,
         survivalcurve = surv_curves,
         prob_times = {
           # do.call(cbind, ...) always yields a matrix with length(times)
           # rows, whereas sapply() collapses to a vector for a single time.
           probs <- do.call(cbind, lapply(curves, function(curve) predict_prob(curve, time_points, times)))
           surv <- cbind.data.frame(time = times, probs)
           # Row names may be inherited from the per-curve results; drop them.
           row.names(surv) <- NULL
           surv
         },
         prob_event = {
           if (!missing(newdata)) {
             mf <- stats::model.frame(object$Terms, newdata)
             y <- stats::model.response(mf)
           } else {
             y <- object$response
           }
           event_times <- y[, 1]
           unname(mapply(function(curve, event_time) predict_prob(curve, time_points, event_time),
                         curves, event_times))
         },
         mean_time = {
           unname(vapply(curves, function(curve) predict_mean(curve, time_points), FUN.VALUE = numeric(1)))
         },
         median_time = {
           unname(vapply(curves, function(curve) predict_median(curve, time_points), FUN.VALUE = numeric(1)))
         }
  )
}

#' Graphical Representation of Feature Weights
#'
#' Plot the weights of an mtlr object. If packages ggplot2 and reshape2 are
#' not installed, a bar graph of feature \emph{influence} is given where influence is defined as the sum of absolute values of the
#' feature weights across time. If ggplot2 and reshape2 are installed then a plot of feature weight across time is given.
#'
#'@param x an object of class mtlr (result from calling \code{\link[MTLR]{mtlr}}).
#'@param numfeatures the number of weights to plot. Default is 5. The most influential features are chosen first.
#'@param featurenames the names of the specific weights to plot. These should correspond to the column names
#'of x$weight_matrix; unknown names raise an error. If featurenames are supplied, then numfeatures is ignored.
#'@param digits the number of digits to round to for the value of the time points. Defaults to 0 (whole numbers).
#'@param ... for future methods
#'@return If ggplot2 and reshape2 are installed, a ggplot object; otherwise the value returned by
#'\code{graphics::barplot}.
#'@examples
#'#These examples are geared towards users who have installed ggplot2 and reshape2.
#'library(survival)
#'mod <- mtlr(Surv(time,status)~., data = lung)
#'#Basic plot with 5 most influential features
#'plot(mod)
#'#Plot all 8 features
#'plot(mod, numfeatures = 8)
#'#Suppose we want to see specifically the "meal.cal" and "ph.karno" features:
#'plot(mod, featurenames = c("meal.cal", "ph.karno"))
#'@export
plot.mtlr <- function(x, numfeatures = 5, featurenames = NULL, digits = 0, ...) {
  # requireNamespace() only looks at the first element of a character vector,
  # so each optional dependency has to be checked on its own.
  plot_deps <- c("ggplot2", "reshape2")
  has_plot_deps <- all(vapply(plot_deps, requireNamespace, FUN.VALUE = logical(1),
                              quietly = TRUE))
  if (has_plot_deps) {
    weights <- x$weight_matrix
    time_points <- x$time_points
    if (length(featurenames) == 0) {
      influence <- get_param_influence(x)
      if (numfeatures > length(influence)) {
        warning("Number of features specified greater than the total number of features. This has been replaced with the total number of features.")
        numfeatures <- length(influence)
      }
      top_weights <- influence[order(influence, decreasing = TRUE)[seq_len(numfeatures)]]
      weight_index <- match(names(top_weights), colnames(weights))
    } else {
      weight_index <- match(featurenames, colnames(weights))
      # Unmatched names would otherwise silently become all-NA columns.
      if (anyNA(weight_index)) {
        stop("Unknown feature name(s): ",
             paste(featurenames[is.na(weight_index)], collapse = ", "),
             call. = FALSE)
      }
    }
    plot_weights <- weights[, weight_index, drop = FALSE]
    # Time points become discrete, rounded axis labels kept in model order.
    chr_time <- as.character(round(time_points, digits))
    plot_data <- cbind.data.frame(time = chr_time, plot_weights)
    # Long format: one row per (time, feature) with columns time/variable/value.
    plot_data <- reshape2::melt(plot_data, id.vars = "time")
    # NOTE(review): newer ggplot2 releases discourage `data$col` inside aes()
    # and `size` for lines (in favour of `linewidth`); left unchanged here to
    # keep compatibility with older ggplot2 versions.
    ggplot2::ggplot(plot_data, ggplot2::aes(x = plot_data$time,
                                            y = plot_data$value,
                                            group = plot_data$variable,
                                            color = plot_data$variable)) +
      ggplot2::geom_point(size = 2.5) +
      ggplot2::geom_line(size = 1.25) +
      ggplot2::scale_x_discrete("Event Time", limits = chr_time) +
      ggplot2::theme_bw() +
      ggplot2::labs(y = "Weight", color = "Feature") +
      ggplot2::theme(text = ggplot2::element_text(size = 16),
                     axis.text = ggplot2::element_text(size = 12, face = "bold"))
  } else {
    warning("Installing ggplot2 and reshape2 will give a more informative plot.")
    influence <- get_param_influence(x)
    graphics::barplot(influence, ylab = "Influence", xlab = "Feature")
  }
}

Try the MTLR package in your browser

Any scripts or data that you put into this service are public.

MTLR documentation built on June 4, 2019, 1:02 a.m.