R/plot_gam.R

Defines functions plot_gam

Documented in plot_gam

#' Plot GAM Model
#'
#' Plots fitted GAM values for a focal predictor, keeping any other
#' predictors in the model at a specified quantile (default: median).
#' Numeric covariates are held at the requested quantile, factor covariates
#' at their first level, and character covariates at their most common value.
#' Predictions come from \code{predict()} with its default \code{type}
#' (for \code{mgcv} models this is the link scale).
#'
#' @param model A GAM model object fitted using \code{mgcv::gam()}.
#' @param predictor Character string specifying the name of the (numeric)
#'   predictor variable to plot on the x-axis.
#' @param quantile.others Number between 1 and 99 for the quantile
#'   at which all other numeric predictors are held constant. Default is 50 (median).
#' @param col Color for the prediction line. Default is "blue4".
#' @param bg Background color for the confidence band. Default is
#'   \code{adjustcolor('dodgerblue', .2)}.
#' @param plot2 How to plot the distribution in the lower plot. Options: \code{'auto'} (default,
#'   auto-select based on number of unique values), \code{'freq'} (always plot frequencies),
#'   \code{'density'} (always plot the density) or \code{'none'} (neither). When \code{'auto'},
#'   plots frequencies when the predictor has 30 or fewer unique values, density otherwise.
#' @param col2 Color for the lines/bars in the bottom distribution plot. Default is \code{NULL},
#'   which uses the \code{plot_freq()} default for frequencies and "dodgerblue" for the density.
#' @param bg2 Background color for the bottom distribution plot. Default is "gray90".
#' @param ... Additional arguments passed to \code{plot()} (e.g. \code{main}, \code{xlab},
#'   \code{ylim}) and, except for plot-only arguments such as titles, axis labels and
#'   axis limits, to \code{lines()} (e.g. \code{lwd}, \code{lty}).
#'
#' @return Invisibly returns a list containing:
#' \itemize{
#'   \item \code{predictor_values}: The sequence of predictor values used
#'   \item \code{predicted}: The predicted values
#'   \item \code{se}: The standard errors
#'   \item \code{lower}: Lower confidence bound (predicted - 2*se)
#'   \item \code{upper}: Upper confidence bound (predicted + 2*se)
#' }
#'
#' @examples
#' \donttest{
#' library(mgcv)
#' # Fit a GAM model
#' data(mtcars)
#' mtcars$cyl <- factor(mtcars$cyl)  # Convert to factor before fitting GAM
#' model <- gam(mpg ~ s(hp) + s(wt) + cyl, data = mtcars)
#'
#' # Plot effect of hp (with other variables at median)
#' plot_gam(model, "hp")
#'
#' # Plot effect of hp (with other variables at 25th percentile)
#' plot_gam(model, "hp", quantile.others = 25)
#'
#' # Customize plot
#' plot_gam(model, "hp", main = "Effect of Horsepower", col = "blue", lwd = 2)
#' }
#'
#' @importFrom mgcv gam
#' @export
plot_gam <- function(model, predictor, quantile.others = 50, 
                     col = "blue4", bg = adjustcolor('dodgerblue', .2), 
                     plot2 = 'auto', col2 = NULL, bg2 = "gray90", ...) {
  # Check if model is from mgcv::gam()
  if (!inherits(model, "gam")) {
    stop("'model' must be a GAM model object fitted using mgcv::gam()")
  }
  
  # Check if mgcv is available (predict() dispatches to mgcv's method)
  if (!requireNamespace("mgcv", quietly = TRUE)) {
    stop("Package 'mgcv' is required. Please install it with: install.packages('mgcv')")
  }
  
  # factor() inside the formula makes the model frame hold a column named
  # e.g. "factor(x)" rather than "x", so newdata built from the model frame
  # cannot be used by predict(). Ask the user to convert before fitting.
  model_formula_text <- paste(deparse(formula(model)), collapse = " ")
  if (grepl("\\bfactor\\s*\\(", model_formula_text, ignore.case = FALSE)) {
    message2("plot_gam() says: A variable in the GAM formula was included with 'factor()'.\n",
         "Please, instead, convert that variable to factor  before running the GAM model\n",
         "\nExample: Instead of gam(y ~ factor(x), data = df), do:\n",
         "    df$x <- factor(df$x)\n",
         "    gam(y ~ x, data = df)", col = 'red', stop = TRUE)
  }
  
  # Validate quantile.others (is.na() guard: `if (NA < 1)` would crash)
  if (!is.numeric(quantile.others) || length(quantile.others) != 1 ||
      is.na(quantile.others)) {
    stop("'quantile.others' must be a single numeric value")
  }
  if (quantile.others < 1 || quantile.others > 99) {
    stop("'quantile.others' must be between 1 and 99")
  }
  
  # Validate predictor
  if (!is.character(predictor) || length(predictor) != 1) {
    stop("'predictor' must be a single character string")
  }
  
  # Validate plot2
  if (!is.character(plot2) || length(plot2) != 1) {
    stop("'plot2' must be a single character string")
  }
  if (!plot2 %in% c('auto', 'freq', 'density', 'none')) {
    stop("'plot2' must be one of: 'auto', 'freq', 'density', 'none'")
  }
  
  # Extract model data: model$model first, then model.frame() as fallback
  model_data <- model$model
  if (is.null(model_data)) {
    model_data <- tryCatch(
      model.frame(model),
      error = function(e) {
        stop("Model data not found. Please refit the model with 'keepData = TRUE' or ensure model data is available.",
             call. = FALSE)
      }
    )
  }
  
  # The first model-frame column is the response; the rest are predictors
  all_vars <- names(model_data)
  response_var <- all_vars[1]
  predictor_vars <- all_vars[-1]
  
  # Check if predictor exists in model data
  if (!predictor %in% predictor_vars) {
    stop(sprintf("Predictor '%s' not found in model variables. Available variables: %s",
                 predictor, paste(predictor_vars, collapse = ", ")))
  }
  
  # The focal predictor spans the x-axis, so it must be numeric
  predictor_values <- model_data[[predictor]]
  if (!is.numeric(predictor_values)) {
    stop(sprintf("Predictor '%s' must be numeric to be plotted on the x-axis (it has class '%s')",
                 predictor, class(predictor_values)[1]))
  }
  
  # Other variables (all predictors except the one we're plotting)
  other_vars <- setdiff(predictor_vars, predictor)
  
  # 100 equally spaced values between min and max of the focal predictor
  predictor_seq <- seq(min(predictor_values, na.rm = TRUE),
                       max(predictor_values, na.rm = TRUE),
                       length.out = 100)
  n_rows <- length(predictor_seq)
  
  # Replicate the model-frame structure first so column classes, factor
  # levels and attributes match what the model was fitted on
  new_data <- model_data[rep(1, n_rows), , drop = FALSE]
  new_data[[predictor]] <- predictor_seq
  
  # Hold every other variable at a single representative value
  for (var in other_vars) {
    var_values <- model_data[[var]]
    
    if (is.factor(var_values)) {
      # Factors: first level, keeping the full level set
      first_level <- levels(var_values)[1]
      new_data[[var]] <- factor(rep(first_level, n_rows), levels = levels(var_values))
    } else if (is.character(var_values)) {
      # Character vectors: most common value
      var_levels <- table(var_values)
      most_common <- names(var_levels)[which.max(var_levels)]
      new_data[[var]] <- rep(most_common, n_rows)
    } else if (is.numeric(var_values)) {
      # Numeric: the requested quantile
      quantile_value <- quantile(var_values, probs = quantile.others / 100, na.rm = TRUE)
      new_data[[var]] <- rep(as.numeric(quantile_value), n_rows)
    } else {
      # Anything else: fall back to the most common value
      warning(sprintf("Variable '%s' has unsupported type. Using most common value.", var))
      var_levels <- table(var_values)
      most_common <- names(var_levels)[which.max(var_levels)]
      new_data[[var]] <- rep(most_common, n_rows)
    }
  }
  
  # Keep only predictor columns (exclude response) for prediction
  new_data <- new_data[, predictor_vars, drop = FALSE]
  
  # Predictions with standard errors (predict()'s default type)
  pred_result <- predict(model, newdata = new_data, se.fit = TRUE)
  predicted <- pred_result$fit
  se <- pred_result$se.fit
  
  # Confidence bounds (2 standard errors)
  lower <- predicted - 2 * se
  upper <- predicted + 2 * se
  
  # `user_dots` keeps the caller's arguments for lines(); `dots` is the
  # working copy for plot(). `col`/`bg` are formals, so they never land here.
  user_dots <- list(...)
  dots <- user_dots
  
  # Default labels
  dots <- set_default(dots, "xlab", predictor)
  dots <- set_default(dots, "ylab", response_var)
  
  # Main title is drawn with mtext() (together with the formula subtitle)
  main_title_text <- if ("main" %in% names(dots)) {
    dots[["main"]]
  } else {
    paste0("GAM Predicting '", response_var, "' with '", predictor, "'")
  }
  dots$main <- ""
  
  # Default formatting
  dots <- set_default(dots, "font.lab", 2)
  dots <- set_default(dots, "cex.lab", 1.2)
  dots <- set_default(dots, "las", 1)
  dots <- set_default(dots, "font.main", 2)  # bold
  dots <- set_default(dots, "cex.main", 1.2)
  
  # Axis limits: cover the confidence band and the predictor range
  if (!"ylim" %in% names(dots)) {
    dots$ylim <- range(c(lower, upper), na.rm = TRUE)
  }
  if (!"xlim" %in% names(dots)) {
    dots$xlim <- range(predictor_seq, na.rm = TRUE)
  }
  
  # Decide whether (and how) to draw the distribution panel
  plot_distribution <- plot2 != 'none'
  use_plot_freq <- FALSE
  use_plot_density <- FALSE
  if (plot_distribution) {
    n_unique <- length(unique(predictor_values))
    if (plot2 == 'freq') {
      use_plot_freq <- TRUE
    } else if (plot2 == 'density') {
      use_plot_density <- TRUE
    } else if (n_unique <= 30) {
      # 'auto': frequencies for few distinct values, density otherwise
      use_plot_freq <- TRUE
    } else {
      use_plot_density <- TRUE
    }
  }
  
  # Save/restore only what we change in par(); restoring full `par()` can
  # reset `mfg` and break the caller's multi-panel plotting.
  old_mar <- par("mar")
  on.exit(par(mar = old_mar), add = TRUE)
  
  if (plot_distribution) {
    # Only undo layout() when we actually installed one; doing it
    # unconditionally would restart the caller's par(mfrow = ...) page.
    old_mfrow <- par("mfrow")
    old_mfcol <- par("mfcol")
    on.exit({
      layout(1)
      par(mfrow = old_mfrow, mfcol = old_mfcol)
    }, add = TRUE)
    
    # Top panel: GAM plot (75%); bottom panel: distribution (25%), no gap
    layout(matrix(c(1, 2), nrow = 2, ncol = 1), heights = c(0.75, 0.25))
    # No bottom margin; extra top margin for title and subtitle
    par(mar = c(0, 4.1, 5.1, 2.1))
    # The x-axis is drawn on the bottom panel instead
    dots$xaxt <- "n"
  } else if (old_mar[3] < 5) {
    # Single panel: ensure enough top margin for title and subtitle
    par(mar = c(old_mar[1], old_mar[2], 5.1, old_mar[4]))
  }
  
  # Set up the plotting region only (type = "n"); the line is drawn after
  # the band so it sits on top. A user `type` is forwarded to lines() instead.
  dots$type <- NULL
  do.call(plot, c(list(x = predictor_seq, y = predicted, type = "n"), dots))
  
  # Main title (line 2) and formula subtitle (line 1, 90% size, italic)
  main_cex <- if ("cex.main" %in% names(dots)) dots[["cex.main"]] else 1.2
  main_font <- if ("font.main" %in% names(dots)) dots[["font.main"]] else 2
  mtext(main_title_text, side = 3, line = 2, cex = main_cex, font = main_font)
  mtext(model_formula_text, side = 3, line = 1, cex = main_cex * 0.9,
        font = 3, col = "gray70")
  
  # Confidence band
  polygon(c(predictor_seq, rev(predictor_seq)), 
          c(lower, rev(upper)),
          col = bg,
          border = NA)
  
  # Prediction line on top. Arguments that only plot() understands would
  # make lines() warn "not a graphical parameter", so they are dropped.
  plot_only_args <- c("main", "sub", "xlab", "ylab", "xlim", "ylim", "log",
                      "asp", "axes", "frame.plot", "panel.first", "panel.last")
  user_names <- names(user_dots)
  if (is.null(user_names)) {
    user_names <- character(length(user_dots))
  }
  line_dots <- user_dots[!user_names %in% plot_only_args]
  do.call(lines, c(list(x = predictor_seq, y = predicted, col = col), line_dots))
  
  # Distribution in the bottom panel
  if (plot_distribution) {
    par(mar = c(5.1, 4.1, 0, 2.1))  # no top margin
    
    # Share the x-range of the GAM plot; match its y-label size
    xlim_dist <- dots[["xlim"]]
    ylab_cex <- if ("cex.lab" %in% names(dots)) dots[["cex.lab"]] else 1.2
    
    if (use_plot_freq) {
      # ylim from the largest frequency, plus headroom for value labels
      freq_table <- table(predictor_values)
      max_freq <- max(freq_table, na.rm = TRUE)
      ylim_dist <- c(0, max_freq)
      if (max_freq > 0) {
        ylim_dist[2] <- max_freq + max(1, max_freq * 0.15)  # 15% or at least 1 unit
      }
      
      init_bottom_plot(xlim = xlim_dist, 
                       ylim = ylim_dist,
                       xlab = predictor, 
                       ylab = "",  # drawn with mtext() below
                       bg = bg2, 
                       cex.lab = ylab_cex)
      
      # plot_freq() draws onto the prepared background (.overlay = TRUE)
      plot_freq_args <- list(formula = predictor_values, 
                             xlab = predictor,
                             main = "",
                             xlim = xlim_dist,
                             .overlay = TRUE)
      if (!is.null(col2)) {
        plot_freq_args$col <- col2
      }
      do.call(plot_freq, plot_freq_args)
      
      # .overlay = TRUE does not draw axes
      axis(1)
      axis(2, las = 1)
      mtext(side = 2, text = "Frequency", line = 2.5, font = 2, cex = ylab_cex, col = "gray30")
    } else if (use_plot_density) {
      # Compute the density first to size the y-axis
      density_obj <- density(predictor_values)
      ylim_density <- c(0, max(density_obj$y) * 1.3)
      
      init_bottom_plot(xlim = xlim_dist, 
                       ylim = ylim_density,
                       xlab = predictor, 
                       ylab = "",  # drawn with mtext() below
                       bg = bg2, 
                       cex.lab = ylab_cex)
      
      density_col <- if (!is.null(col2)) col2 else "dodgerblue"
      lines(density_obj, col = density_col, lwd = 4)
      
      axis(1)
      axis(4, las = 1)  # tick labels on the right, away from the GAM plot
      mtext(side = 2, text = "Density", line = 2.5, font = 2, cex = ylab_cex, col = "gray30")
    }
  }
  
  invisible(list(
    predictor_values = predictor_seq,
    predicted = predicted,
    se = se,
    lower = lower,
    upper = upper
  ))
}

Try the statuser package in your browser

Any scripts or data that you put into this service are public.

statuser documentation built on April 25, 2026, 5:06 p.m.