# R/scatter.gam.R

# Defines functions scatter.gam

# Documented in scatter.gam

#' Scatter Plot with GAM Smooth Line
#'
#' Creates a scatter plot with a GAM (Generalized Additive Model) smooth line.
#' Supports both \code{scatter.gam(x, y)} and \code{scatter.gam(y ~ x)}.
#'
#' @param x A numeric vector of x values, or a formula of the form \code{y ~ x}.
#' @param y A numeric vector of y values. Not used if \code{x} is a formula.
#' @param data.dots Logical. If TRUE (default), displays the raw data points on the scatter plot.
#' @param three.dots Logical. If TRUE, divides x into tertiles and puts markers on the average x & y for each. Default is FALSE.
#' @param data An optional data frame containing the variables \code{x} and \code{y}.
#' @param k Optional integer specifying the basis dimension for the smooth term
#'   in the GAM model (passed to \code{s(x, k=k)}). If NULL (default), uses the
#'   default basis dimension.
#' @param plot.dist Character string specifying how to plot the distribution of \code{x}
#'   underneath the scatter plot. Options: \code{NULL} (default, auto-select based on
#'   number of unique values), \code{"none"} (no distribution plot), \code{"plot_freq"}
#'   (always use \code{plot_freq()}), or \code{"hist"} (always use \code{hist()}).
#'   When \code{NULL}, uses \code{plot_freq()} if there are 25 or fewer unique values,
#'   otherwise uses \code{hist()}.
#' @param dot.pch Plotting character for data points when \code{data.dots = TRUE}.
#'   Default is 16 (filled circle).
#' @param dot.col Color for data points when \code{data.dots = TRUE}. Default is
#'   \code{adjustcolor('gray', 0.7)} (semi-transparent gray).
#' @param jitter Logical. If TRUE, applies a small amount of jitter to data points
#'   to reduce overplotting. Default is FALSE.
#' @param ... Additional arguments passed to \code{plot()} and \code{gam()}.
#'   Common plot arguments include:
#'   \itemize{
#'     \item \code{main}: Custom title for the plot (e.g., \code{main = "My Title"})
#'     \item \code{col}: Color of the GAM smooth line (e.g., \code{col = "red"})
#'     \item \code{lwd}: Line width of the GAM smooth line (e.g., \code{lwd = 2})
#'     \item \code{xlim}, \code{ylim}: Axis limits (e.g., \code{xlim = c(0, 10)})
#'     \item \code{xlab}, \code{ylab}: Axis labels (e.g., \code{xlab = "Age"})
#'   }
#'
#' @return Invisibly returns the fitted GAM model object.
#'
#' @details
#' This function fits a GAM model with a smooth term for x and plots the fitted
#' smooth line. The function uses the \code{mgcv} package's \code{gam()} function.
#'
#' When \code{three.dots = TRUE}, the x variable is divided into three equal-sized
#' groups (tertiles), and the mean x and y values for each group are plotted as
#' points. This provides a simple summary of the relationship across the range of x.
#'
#' @examples
#' # Generate sample data for examples
#' x <- rnorm(100)
#' y <- 2*x + rnorm(100)
#'
#' # Plot the GAM smooth line over the raw data points (the default)
#' scatter.gam(x, y)
#'
#' # Equivalent call using formula syntax (y ~ x)
#' scatter.gam(y ~ x)
#'
#' # Include scatter plot with underlying data points behind the GAM line
#' scatter.gam(x, y, data.dots = TRUE)
#'
#' # Include summary points showing mean x and y for each tertile bin
#' scatter.gam(x, y, three.dots = TRUE)
#'
#' # Customize the plot with a custom title, line color, and line width
#' scatter.gam(x, y, data.dots = TRUE, col = "red", lwd = 2, main = "GAM Fit")
#'
#' # Control smoothness of the GAM line by specifying the basis dimension
#' scatter.gam(x, y, k = 10)
#'
#' @seealso \code{\link[stats]{scatter.smooth}} for a simpler loess-based scatter plot smoother.
#'
#' @importFrom mgcv gam
#' @export
scatter.gam <- function(x, y, data.dots = TRUE, three.dots = FALSE, data = NULL, k = NULL, plot.dist = NULL,
                        dot.pch = 16, dot.col = adjustcolor('gray', 0.7), jitter = FALSE, ...) {
  # Overview
  #   1. Resolve x / y (and their axis labels) from either `x, y` or `y ~ x`,
  #      optionally looking variables up in `data`.
  #   2. Split `...` into arguments for mgcv::gam() and arguments for plot().
  #   3. Fit y ~ s(x) and predict the smooth on a 400-point grid.
  #   4. Draw the scatter panel (frame, points, tertile means, smooth line).
  #   5. Optionally draw the distribution of x in a second panel underneath.
  # Returns the fitted GAM invisibly. Graphical parameters changed here are
  # restored on exit.

  # Collapse a (possibly multi-line) deparse() result into a single string
  expr_text <- function(expr) {
    paste(deparse(expr, width.cutoff = 60L), collapse = " ")
  }
  # Drop a leading "df$" / "a$b$" prefix so axis labels show only the column
  # name. Only a leading chain of identifiers is removed, so calls such as
  # "log(df$x)" are left intact instead of being truncated to "x)".
  strip_df_prefix <- function(txt) {
    sub("^([[:alpha:].][[:alnum:]._]*\\$)+", "", trimws(txt))
  }

  # Evaluating `x` may fail when it names a column that only exists in `data`;
  # in that case it cannot be a formula, so treat the failure as FALSE.
  is_formula <- tryCatch(inherits(x, "formula"), error = function(e) FALSE)

  if (!is.null(data) && !is.data.frame(data)) {
    stop("scatter.gam(): 'data' must be a data frame", call. = FALSE)
  }

  if (is_formula) {
    form <- x
    # A two-sided formula has length 3 (`~`, LHS, RHS); one-sided ones are rejected
    if (length(form) != 3L) {
      stop("scatter.gam(): Formula must have exactly two terms: y ~ x", call. = FALSE)
    }
    formula_env <- environment(form)
    if (is.null(formula_env)) {
      formula_env <- parent.frame()
    }
    lhs <- form[[2L]]
    rhs <- form[[3L]]

    if (is.name(lhs) && is.name(rhs)) {
      # Bare variable names on both sides: look them up by name
      y_var_name <- as.character(lhs)
      x_var_name <- as.character(rhs)
      if (!is.null(data)) {
        if (!y_var_name %in% names(data)) {
          stop(sprintf("scatter.gam(): Variable \"%s\" not found in dataset", y_var_name), call. = FALSE)
        }
        if (!x_var_name %in% names(data)) {
          stop(sprintf("scatter.gam(): Variable \"%s\" not found in dataset", x_var_name), call. = FALSE)
        }
        y <- data[[y_var_name]]
        x <- data[[x_var_name]]
      } else {
        if (!exists(y_var_name, envir = formula_env, inherits = TRUE)) {
          stop(sprintf("scatter.gam(): Could not find variable '%s'", y_var_name), call. = FALSE)
        }
        if (!exists(x_var_name, envir = formula_env, inherits = TRUE)) {
          stop(sprintf("scatter.gam(): Could not find variable '%s'", x_var_name), call. = FALSE)
        }
        y <- get(y_var_name, envir = formula_env, inherits = TRUE)
        x <- get(x_var_name, envir = formula_env, inherits = TRUE)
      }
      y_name <- y_var_name
      x_name <- x_var_name
    } else {
      # At least one side is an expression (e.g. mpg$hwy ~ mpg$displ or
      # y ~ log(x)): evaluate both sides, with `data` columns taking precedence
      # over the formula environment.
      eval_env <- if (!is.null(data)) list2env(data, parent = formula_env) else formula_env
      y <- eval(lhs, envir = eval_env)
      x <- eval(rhs, envir = eval_env)
      if (!is.numeric(y) || !is.numeric(x)) {
        stop("scatter.gam(): Formula sides must evaluate to numeric vectors", call. = FALSE)
      }
      y_name <- strip_df_prefix(expr_text(lhs))
      x_name <- strip_df_prefix(expr_text(rhs))
    }
  } else {
    # Standard syntax: x, y
    if (missing(x) || missing(y)) {
      stop("scatter.gam(): both 'x' and 'y' are required unless 'x' is a formula", call. = FALSE)
    }
    # Capture the calling expressions before x / y are possibly overwritten;
    # quotes are removed so that scatter.gam("x", "y", data = df) also works.
    x_name_raw <- gsub('^"|"$', '', expr_text(substitute(x)))
    y_name_raw <- gsub('^"|"$', '', expr_text(substitute(y)))
    x_name <- strip_df_prefix(x_name_raw)
    y_name <- strip_df_prefix(y_name_raw)

    if (!is.null(data)) {
      # Column lookup uses the raw (unstripped) expression text
      if (!x_name_raw %in% names(data)) {
        stop(sprintf("scatter.gam(): Column '%s' not found in data", x_name_raw), call. = FALSE)
      }
      if (!y_name_raw %in% names(data)) {
        stop(sprintf("scatter.gam(): Column '%s' not found in data", y_name_raw), call. = FALSE)
      }
      x <- data[[x_name_raw]]
      y <- data[[y_name_raw]]
    }
  }

  # Fail early with a clear message instead of an obscure model-frame error
  if (length(x) != length(y)) {
    stop(sprintf("scatter.gam(): 'x' and 'y' must have the same length (%d vs %d)",
                 length(x), length(y)), call. = FALSE)
  }

  # Check for required package
  if (!requireNamespace("mgcv", quietly = TRUE)) {
    stop("Package 'mgcv' is required. Please install it with: install.packages('mgcv')")
  }

  # Split `...`: names listed here go to gam(), everything else goes to plot()
  dots <- list(...)
  gam_arg_names <- c("family", "weights", "method", "optimizer", "control",
                     "scale", "select", "knots", "sp", "min.sp",
                     "H", "gamma", "fit", "paraPen", "G", "in.out",
                     "drop.unused.levels", "drop.intercept", "nthreads",
                     "cluster", "mustart", "etastart", "offset", "subset",
                     "na.action", "start", "model", "x", "y")
  is_gam_arg <- names(dots) %in% gam_arg_names
  gam_args <- dots[is_gam_arg]
  plot_args <- dots[!is_gam_arg]

  # Fit the GAM; x, y and k are resolved through the formula's environment,
  # which is this function's frame
  gam_formula <- if (is.null(k)) y ~ s(x) else y ~ s(x, k = k)
  g1 <- do.call(mgcv::gam, c(list(formula = gam_formula), gam_args))

  # Arguments this function places itself are pulled out of plot_args (with
  # exact `[[` matching) so that do.call(plot, ...) never receives the same
  # formal twice.
  user_xlim <- plot_args[["xlim"]]
  user_ylim <- plot_args[["ylim"]]
  user_main <- plot_args[["main"]]
  # `col` / `lwd` style the smooth line, as documented for `...`
  line_col <- if (is.null(plot_args[["col"]])) "blue" else plot_args[["col"]]
  line_lwd <- if (is.null(plot_args[["lwd"]])) 3 else plot_args[["lwd"]]
  for (nm in c("xlim", "ylim", "main", "col", "lwd", "type")) {
    plot_args[[nm]] <- NULL
  }

  # One shared xlim keeps the x-axes of both panels aligned; the prediction
  # grid spans exactly that range.
  xlim_common <- if (is.null(user_xlim)) range(x, na.rm = TRUE) else user_xlim
  newdat <- data.frame(x = seq(xlim_common[1], xlim_common[2], length.out = 400))
  yh <- predict(g1, newdata = newdat)

  show_data <- isTRUE(data.dots)
  show_three <- isTRUE(three.dots)

  # Tertile summary: mean x and mean y within three equal-sized groups of x
  if (show_three) {
    tert_breaks <- unique(stats::quantile(x, probs = c(0, 1 / 3, 2 / 3, 1),
                                          na.rm = TRUE, names = FALSE))
    xq <- if (length(tert_breaks) >= 2L) {
      cut(x, breaks = tert_breaks, include.lowest = TRUE)
    } else {
      cut(x, 3)  # degenerate quantiles: fall back to equal-width bins
    }
    x3_means <- as.numeric(tapply(x, xq, mean, na.rm = TRUE))
    y3_means <- as.numeric(tapply(y, xq, mean, na.rm = TRUE))
  }

  # ylim: the user's value wins; otherwise cover everything that is drawn
  if (is.null(user_ylim)) {
    y_pool <- yh
    if (show_data) y_pool <- c(y_pool, y)
    if (show_three) y_pool <- c(y_pool, y3_means)
    ylim <- range(y_pool, na.rm = TRUE)
  } else {
    ylim <- user_ylim
  }

  # Decide whether / how to draw the distribution of x below the scatter plot
  plot_distribution <- !identical(plot.dist, "none")
  use_plot_freq <- FALSE
  use_hist <- FALSE
  if (plot_distribution) {
    if (identical(plot.dist, "plot_freq")) {
      use_plot_freq <- TRUE
    } else if (identical(plot.dist, "hist")) {
      use_hist <- TRUE
    } else if (length(unique(x)) <= 25) {
      # Auto-select: few distinct values -> frequency plot
      use_plot_freq <- TRUE
    } else {
      use_hist <- TRUE
    }
  }

  # Save current par settings and restore on exit
  old_par <- par(no.readonly = TRUE)
  on.exit(par(old_par), add = TRUE)

  if (plot_distribution) {
    # Two stacked panels (scatter on top, distribution below) with no gap.
    # Left / right margins are identical in both panels so the x-axes align.
    layout(matrix(c(1, 2), nrow = 2, ncol = 1), heights = c(2, 1))
    par(mar = c(0, 5.1, 3, 4.1))
  } else {
    # Single panel: only the top margin is changed (3 lines for the title)
    par(mar = c(old_par$mar[1], old_par$mar[2], 3, old_par$mar[4]))
  }

  # Defaults for labels and label formatting, unless supplied by the caller
  if (!"xlab" %in% names(plot_args)) plot_args$xlab <- x_name
  if (!"ylab" %in% names(plot_args)) plot_args$ylab <- y_name
  if (!"font.lab" %in% names(plot_args)) plot_args$font.lab <- 2
  if (!"cex.lab" %in% names(plot_args)) plot_args$cex.lab <- 1.2
  if (!"las" %in% names(plot_args)) plot_args$las <- 1

  # The bottom panel carries the x-axis when a distribution is drawn
  if (plot_distribution) {
    plot_args$xaxt <- "n"
  }

  # Empty frame first; points are added next and the smooth line last so the
  # line ends up on top
  frame_args <- c(list(x = newdat$x, y = yh, type = "n", ylim = ylim, xlim = xlim_common),
                  plot_args)
  do.call(plot, frame_args)

  # Title is drawn with mtext() (main was removed from plot_args above)
  title_text <- if (is.null(user_main)) {
    paste0("Scatter GAM - ", x_name, " & ", y_name)
  } else {
    user_main
  }
  mtext(side = 3, text = title_text, line = 1, font = 2, cex = 1.2)

  if (show_data) {
    # Average number of observations per distinct x value
    density_ratio <- length(x) / length(unique(x))

    use_jitter <- jitter
    point_col <- dot.col
    point_cex <- 0.8

    # Dense data: switch jitter on automatically and, when the caller kept the
    # default colour, draw smaller and more transparent points
    if (!jitter && density_ratio > 5) {
      use_jitter <- TRUE
      if (identical(dot.col, adjustcolor('gray', 0.7))) {
        very_dense <- density_ratio > 10
        point_col <- adjustcolor('gray', if (very_dense) 0.3 else 0.4)
        point_cex <- if (very_dense) 0.5 else 0.6
      }
    }

    if (use_jitter) {
      # Uniform noise of up to +/- 1% of each variable's range
      x_amt <- diff(range(x, na.rm = TRUE)) * 0.01
      y_amt <- diff(range(y, na.rm = TRUE)) * 0.01
      x_plot <- x + runif(length(x), -x_amt, x_amt)
      y_plot <- y + runif(length(y), -y_amt, y_amt)
    } else {
      x_plot <- x
      y_plot <- y
    }
    points(x_plot, y_plot, pch = dot.pch, col = point_col, cex = point_cex)
  }

  # Tertile means
  if (show_three) {
    points(x3_means, y3_means, pch = 16, cex = 1.3)
  }

  # Smooth line on top of the points
  lines(newdat$x, yh, col = line_col, lwd = line_lwd)

  if (plot_distribution) {
    # Bottom panel: no top margin, same left / right margins as the top panel
    par(mar = c(5.1, 5.1, 0, 4.1))
    xlim_dist <- xlim_common
    # cex.lab always exists in plot_args here (defaulted above)
    cex_lab <- plot_args[["cex.lab"]]

    if (use_plot_freq) {
      # y-range of the frequency panel with a 10% buffer at the top
      max_freq <- max(table(x), na.rm = TRUE)
      ylim_freq <- c(0, if (max_freq > 0) max_freq * 1.10 else 0)

      # init_bottom_plot() and plot_freq() are defined elsewhere in the
      # package; the first is used to set up the panel with a gray95
      # background, the second is called with .overlay = TRUE to draw on it.
      init_bottom_plot(xlim = xlim_dist,
                       ylim = ylim_freq,
                       xlab = x_name,
                       ylab = "",  # label is added with mtext() below
                       bg = "gray95",
                       cex.lab = cex_lab)

      # x is passed unnamed, as the first positional argument of plot_freq()
      plot_freq_args <- list(x,
                             xlab = x_name,
                             main = "",
                             xlim = xlim_dist,
                             .overlay = TRUE)
      do.call(plot_freq, plot_freq_args)

      # Axes are drawn here, after the overlay call
      axis(1)
      axis(2, las = 1, col = "gray40", col.axis = "gray40")
    } else if (use_hist) {
      # Compute the histogram once to size ylim (10% buffer), then draw it
      # with the same breaks
      h <- hist(x, plot = FALSE, breaks = "Sturges")
      max_count <- max(h$counts, na.rm = TRUE)
      ylim_hist <- c(0, if (max_count > 0) max_count * 1.10 else 0)
      hist_args <- list(x = x, yaxt = "n", ylim = ylim_hist, breaks = h$breaks,
                        xlab = x_name, ylab = "", main = "", xlim = xlim_dist,
                        font.lab = plot_args[["font.lab"]],
                        cex.lab = cex_lab,
                        las = plot_args[["las"]])
      do.call(hist, hist_args)
      axis(2, las = 1, col = "gray40", col.axis = "gray40")
    }

    # "Frequency" label at line 3, the default ylab position (mgp[1] = 3)
    if (use_plot_freq || use_hist) {
      mtext(side = 2, text = "Frequency", line = 3, font = 2, cex = cex_lab, col = "gray40")
    }
  }

  # Return GAM model invisibly
  invisible(g1)
}

# Try the statuser package in your browser

# Any scripts or data that you put into this service are public.

# statuser documentation built on April 25, 2026, 5:06 p.m.