R/plot.R

Defines functions plot_trajectory plot.brokenstick

Documented in plot.brokenstick plot_trajectory

#' Plot observed and fitted trajectories by group
#'
#' The `plot` method for a `brokenstick` object plots the observed and
#' fitted trajectories of one or more groups.
#'
#' By default, `plot(fit)` will plot the observed and fitted data for the
#' first three groups in the data. The default setting drops the fitted value
#' at the right boundary knot from the display.
#'
#' @param x An object of class `brokenstick`.
#' @param ... Extra arguments passed down to [predict.brokenstick()]
#' and [plot_trajectory()].
#' @inheritParams predict.brokenstick
#' @return An object of class [ggplot2::ggplot].
#' @author Stef van Buuren 2023
#' @method plot brokenstick
#' @seealso [predict.brokenstick], [plot_trajectory].
#' @examples
#' \dontrun{
#' # fit model on raw hgt with knots at 0, 1, 2 and 3 years
#' fit1 <- brokenstick(hgt ~ age | id, smocc_200, knots = 0:2)
#' gp <- c(10001, 10005, 10022)
#' plot(fit1, group = gp, xlab = "Age (years)", ylab = "Length (cm)")
#'
#' # fit model on standard deviation score
#' fit2 <- brokenstick(hgt_z ~ age | id, smocc_200, knots = 0:2)
#' plot(fit2, group = gp, xlab = "Age (years)", ylab = "Length (SDS)")
#'
#' # built-in model with 11 knots
#' plot(fit_200, group = gp, xlab = "Age (years)", ylab = "Length (SDS)")
#'
#' # black and white version
#' plot(fit_200, group = gp, xlab = "Age (years)", ylab = "Length (SDS)",
#'     color_y = rep("black", 2), shape_y = 1, linetype_y = 3,
#'     color_yhat = rep("grey20", 2), shape_yhat = NA)
#' }
#' @export
plot.brokenstick <- function(x,
                             newdata = NULL,
                             ...) {
  # helper defined elsewhere; presumably makes sure ggplot2 can be used
  install.on.demand("ggplot2", ...)

  # resolve the data to plot, then verify that it contains every variable
  # named in the model before handing over to the plotting routine
  newdata <- get_newdata(x, newdata)
  model_vars <- unname(unlist(x$names))
  is_present <- model_vars %in% colnames(newdata)
  if (!all(is_present)) {
    stop(
      "Variable(s) not found: ",
      paste(model_vars[!is_present], collapse = ", "),
      call. = FALSE
    )
  }

  # the actual drawing is delegated to plot_trajectory()
  plot_trajectory(x = x, newdata = newdata, ...)
}


#' Plot observed and fitted trajectories from fitted brokenstick model
#'
#' @inheritParams predict.brokenstick
#' @param x An object of class `brokenstick`.
#' @param newdata A `data.frame` or `matrix`
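#' @param hide Character indicating which knots should be hidden from the
#' display. One of `"right"` (default), `"left"`, `"boundary"`,
#' `"internal"` or `"none"`.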
#' @param .x The `x` argument of the [predict.brokenstick()] function.
#' @param color_y A character vector with two elements specifying the symbol and line color of the measured data points
#' @param size_y Dot size of measured data points
#' @param linetype_y Line type of data points
#' @param shape_y Symbol for data points
#' @param color_yhat A character vector with two elements specifying the symbol and line color of the predicted data points
#' @param size_yhat Dot size of predicted data points
#' @param linetype_yhat Line type of predicted data
#' @param shape_yhat Symbol for predicted data
#' @param color_imp A character vector with two elements specifying the symbol and line color of the imputed data
#' @param size_imp Dot size of imputed data
#' @param ncol Number of columns in plot
#' @param xlab The label of the x-axis
#' @param ylab The label of the y-axis
#' @param xlim Vector of length 2 with range of x-axis
#' @param ylim Vector of length 2 with range of y-axis
#' @param show A logical vector of length 3. Element 1 specifies
#' whether the observed data are plotted, element 2 specifies
#' whether the broken stick estimates are plotted, element 3 specifies
#' whether imputations are plotted. The default is
#' `c(TRUE, TRUE, FALSE)`.
#' @param n_plot An integer indicating the number of individual plots.
#' The default is 3, which plots the trajectories of the first three
#' groups. The `n_plot` is a safety measure to prevent unintended
#' plots of the entire data set.
#' @param scales Axis scaling, e.g. `"fixed"`, `"free"`, and so on
#' @param theme Plotting theme
#' @param whatknots Deprecated.
#' @param \dots Extra arguments passed down to [predict.brokenstick()].
#' @return An object of class \code{ggplot}
#' @rdname plot_trajectory
#' @seealso [plot.brokenstick]
#' @export
plot_trajectory <- function(x,
                            newdata = NULL,
                            hide = c(
                              "right", "left", "boundary", "internal", "none"
                            ),
                            .x = NULL,
                            group = NULL,
                            color_y = c(
                              grDevices::hcl(240, 100, 40, 0.7),
                              grDevices::hcl(240, 100, 40, 0.8)
                            ),
                            size_y = 2,
                            linetype_y = 1,
                            shape_y = 19,
                            color_yhat = c(
                              grDevices::hcl(0, 100, 40, 0.7),
                              grDevices::hcl(0, 100, 40, 0.8)
                            ),
                            size_yhat = 2,
                            linetype_yhat = 1,
                            shape_yhat = 19,
                            color_imp = c("grey80", "grey80"),
                            size_imp = 2,
                            ncol = 3L,
                            xlab = NULL,
                            ylab = NULL,
                            xlim = NULL,
                            ylim = NULL,
                            show = c(TRUE, TRUE, FALSE),
                            n_plot = 3L,
                            scales = "fixed",
                            theme = ggplot2::theme_light(),
                            whatknots = "droplast",
                            ...) {
  stopifnot(
    inherits(x, "brokenstick"),
    any(show)
  )

  # Translate the deprecated 'whatknots' values into the 'hide' vocabulary.
  # The result is stored on the local copy of the model and is only used
  # below when 'hide' itself was not supplied.
  if (!missing(whatknots)) {
    warning("argument 'whatknots' is deprecated; please use 'hide' instead.",
            call. = FALSE)
    x$hide <- switch(whatknots,
                     droplast = "right",
                     dropfirst = "left",
                     internal = "boundary",
                     all = "none",
                     "none")
  }

  # An explicit 'hide' wins; otherwise use the value stored in the model,
  # falling back to "right".
  if (!missing(hide)) {
    hide <- match.arg(hide)
  } else if (is.null(x$hide)) {
    hide <- "right"
  } else {
    hide <- x$hide
  }

  newdata <- get_newdata(x, newdata)

  # calculate brokenstick predictions, long format
  if (show[2L] && missing(.x)) .x <- "knots"
  data <- predict(
    object = x, newdata = newdata, hide = hide, ...,
    x = .x, group = group
  )
  # a single-column result is combined with newdata and labelled as "data"
  if (ncol(data) == 1L) data <- bind_cols(.source = "data", newdata, data)

  # add imputations
  if (show[3L]) {
    if (!length(x$imp)) {
      stop("Cannot find imputations. Use method 'kr' and set 'nimp'.")
    }

    # the rows of x$imp must line up with the missing outcomes in newdata
    newdata_isna <- is.na(newdata[, x$names$y, drop = TRUE])
    if (sum(newdata_isna) != nrow(x$imp)) {
      stop(
        "Missing data count mismatch: ", sum(newdata_isna),
        " (newdata) versus ", nrow(x$imp), " (x$imp)."
      )
    }

    # one column per imputation, reshaped to long format with the
    # imputation number in '.imp' and the imputed value in the outcome column
    imp <- x$imp
    colnames(imp) <- paste(".imp", seq_len(ncol(imp)), sep = "_")
    imputed <- newdata %>%
      filter(newdata_isna) %>%
      select(-x$names$y) %>%
      bind_cols(as.data.frame(imp)) %>%
      mutate(.source = "imputed") %>%
      pivot_longer(
        cols = starts_with(".imp_"),
        names_prefix = ".imp_",
        names_to = ".imp",
        names_transform = list(.imp = as.integer),
        values_to = x$names$y
      )
    data <- bind_rows(data, imputed)
  }

  # apply trims: drop rows without an x value and observed rows without y
  idx <- rep(TRUE, nrow(data))
  idx <- idx &
    !is.na(data[[x$names$x]]) &
    !(is.na(data[[x$names$y]]) & data[[".source"]] == "data")
  if (!is.null(xlim)) {
    idx <- idx & data[[x$names$x]] >= xlim[1L] & data[[x$names$x]] <= xlim[2L]
  }
  if (!is.null(ylim)) {
    # observed and imputed rows are tested on the outcome column,
    # rows added by predict() on the '.pred' column
    in_ylim <- function(v) v >= ylim[1L] & v <= ylim[2L]
    src <- data[[".source"]]
    y_ok <- in_ylim(data[[x$names$y]])
    pred_ok <- in_ylim(data[[".pred"]])
    idx <- idx &
      ((src == "data" & y_ok) |
         (src == "added" & pred_ok) |
         (src == "imputed" & y_ok))
  }

  # safety measure, restrict to first n_plot cases if no groups are specified;
  # seq_len() avoids the backwards sequence 1:0 when the bound is zero
  if (is.null(group)) {
    ids <- unique(data[[x$names$g]])
    group <- ids[seq_len(min(n_plot, length(ids)))]
  }
  idx <- idx & (data[[x$names$g]] %in% group)

  # process show vector
  if (!show[1L]) idx <- idx & data[[".source"]] != "data"
  if (!show[2L]) idx <- idx & data[[".source"]] != "added"
  if (!show[3L]) idx <- idx & data[[".source"]] != "imputed"

  # Comparisons against NA (e.g. a missing prediction tested against 'ylim')
  # leave NA in the mask. Subsetting with NA would create all-NA rows, so an
  # undecided row is treated as "drop".
  idx[is.na(idx)] <- FALSE
  data <- data[idx, , drop = FALSE]

  if (is.null(xlab)) xlab <- x$names$x
  if (is.null(ylab)) ylab <- x$names$y

  # NOTE(review): aes_string() is deprecated in recent ggplot2 releases.
  # Moving to aes() with the .data pronoun may need a namespace import that
  # is not visible from this file, so the calls are left unchanged here.
  g <- ggplot2::ggplot(data, ggplot2::aes_string(x = x$names$x, y = x$names$y)) +
    ggplot2::xlab(xlab) +
    ggplot2::ylab(ylab)

  if (!is.null(xlim)) g <- g + ggplot2::xlim(xlim)
  if (!is.null(ylim)) g <- g + ggplot2::ylim(ylim)

  # %in% never yields NA, so the per-source masks are safe to use in if ()
  src <- data[[".source"]]

  # add imputed data, one line per imputation
  is_imp <- src %in% "imputed"
  if (any(is_imp)) {
    g <- g +
      ggplot2::geom_line(ggplot2::aes_string(group = ".imp"),
                         data = data[is_imp, ], color = color_imp[2L]
      ) +
      ggplot2::geom_point(ggplot2::aes_string(y = x$names$y),
                          data = data[is_imp, ], color = color_imp[1L],
                          size = size_imp
      )
  }

  # add observed data points and lines
  is_obs <- src %in% "data"
  if (any(is_obs)) {
    g <- g +
      ggplot2::geom_line(data = data[is_obs, ], color = color_y[2L],
                         linetype = linetype_y) +
      ggplot2::geom_point(data = data[is_obs, ], color = color_y[1L],
                          size = size_y, shape = shape_y)
  }

  # add broken stick points and lines; a degree-0 model is drawn as a
  # step function without points
  is_fit <- src %in% "added"
  if (any(is_fit)) {
    if (x$degree == 0L) {
      g <- g + ggplot2::geom_step(ggplot2::aes_string(y = ".pred"),
                                  data = data[is_fit, ],
                                  color = color_yhat[2L],
                                  linetype = linetype_yhat
      )
    } else {
      g <- g +
        ggplot2::geom_line(ggplot2::aes_string(y = ".pred"),
                           data = data[is_fit, ],
                           color = color_yhat[2L],
                           linetype = linetype_yhat
        ) +
        ggplot2::geom_point(ggplot2::aes_string(y = ".pred"),
                            data = data[is_fit, ], color = color_yhat[1L],
                            size = size_yhat, shape = shape_yhat
        )
    }
  }

  # one panel per group
  g +
    ggplot2::facet_wrap(as.formula(paste("~", x$names$g)),
                        ncol = ncol,
                        scales = scales
    ) +
    theme
}

Try the brokenstick package in your browser

Any scripts or data that you put into this service are public.

brokenstick documentation built on March 31, 2023, 9:24 p.m.