# R/plot.csvy.R
#
# Defines functions plot.csvy plot_csvy_control
#
# Documented in plot.csvy plot_csvy_control

#' Control settings for plot.csvy
#'
#' Creates a list of graphical options to customize plots generated by \code{\link{plot.csvy}}.
#' This includes labels, text sizes, colors, shapes, themes, and other display features.
#'
#' @param x1lab Character. Label for the first covariate (x-axis). Default is \code{NULL},
#'   in which case the variable name is used.
#' @param x1_labels Logical or character vector. Labels for the levels of the first covariate.
#'   Default is \code{TRUE}. If \code{TRUE}, labels are created from the data; if a character
#'   vector, that vector is used as the labels; if \code{FALSE}, no labels are drawn.
#' @param x2lab Character. Label for the second covariate. Default is \code{NULL}, in which
#'   case the variable name is used.
#' @param x2_labels Logical or character vector. Labels for the levels of the second covariate.
#'   Default is \code{TRUE}. If \code{TRUE}, labels are created from the data; if a character
#'   vector, that vector is used as the labels; if \code{FALSE}, no labels are drawn.
#' @param x3lab Character. Label for the third covariate, if used (shown in the subtitle).
#'   Default is \code{NULL}.
#' @param x3_labels Logical or character vector. Labels for the levels of the third covariate.
#'   Default is \code{TRUE}. If \code{TRUE}, labels are created from the data; if a character
#'   vector, that vector is used as the labels; if \code{FALSE}, no subtitle describing the
#'   third covariate is shown.
#' @param x4_vals Character vector. For models with more than three predictors, the category
#'   to hold each additional predictor at, in order. Defaults to \code{NULL}, which uses the
#'   most frequent category of each.
#' @param x4_labels Character vector. Custom labels for the categories of the fourth covariate,
#'   in the same order as its categories. Default is \code{NULL}.
#' @param ynm Character. Label for the response. Default is \code{NULL}, in which case the
#'   response name stored in the fit is used.
#' @param ci Logical. If \code{TRUE}, confidence bands are displayed. Defaults to \code{TRUE}.
#' @param legend Logical. If \code{TRUE}, the legend for the constrained or unconstrained fit is
#'   shown. Defaults to \code{TRUE}.
#' @param ylab Logical. If \code{TRUE}, the response name is shown on the y-axis.
#'   Defaults to \code{TRUE}.
#' @param x1size Numeric. Font size for the first covariate's labels. Default is \code{3.8}.
#' @param x2size Numeric. Font size for annotation labels of the second covariate.
#'   Default is \code{3.8}.
#' @param constrained_color Character. Color used to display fitted values and intervals from
#'   the constrained model. Default is \code{"cornflowerblue"}.
#' @param unconstrained_color Character. Color used to display fitted values and intervals from
#'   the unconstrained model. Default is \code{"#A3C99A"}.
#' @param constrained_shape Integer. Shape code (used by \code{ggplot2}) for points
#'   corresponding to constrained fits. Default is \code{16} (solid circle).
#' @param unconstrained_shape Integer. Shape code for points from unconstrained fits.
#'   Default is \code{18} (solid diamond).
#' @param ribbon_fill Character. Fill color for the confidence ribbon around the fitted lines.
#'   Default is \code{"lightblue"}.
#' @param line_color Character. Color of the lines connecting the fitted values.
#'   Default is \code{"black"}.
#' @param base_theme A \code{ggplot2} theme object used as the base plot theme.
#'   Default is \code{ggplot2::theme_minimal()}.
#' @param subtitle.size Numeric. Font size for the subtitle text in the plot. Default is \code{12}.
#' @param angle Numeric. Angle (in degrees) to rotate the x1 labels. Default is \code{0}.
#' @param hjust Numeric. Horizontal justification of the x1 labels (\code{0} = left,
#'   \code{0.5} = centre, \code{1} = right). Default is \code{0.1} (nearly left-aligned).
#' @return A named list of graphical control parameters to be passed to the \code{control}
#'   argument in \code{\link{plot.csvy}}. Arguments set to \code{NULL} are kept as
#'   \code{NULL} entries.
#' @examples
#' plot_csvy_control(
#'   x1lab = "Age Group", 
#'   x2lab = "Region", 
#'   constrained_color = "cornflowerblue", 
#'   unconstrained_color = "gray80",
#'   x1size = 4.5
#' )
#'
#' @export
plot_csvy_control <- function(x1lab = NULL, x1_labels = TRUE, x2lab = NULL, x2_labels = TRUE, 
                              x3lab=NULL, x3_labels = TRUE, x4_vals = NULL, x4_labels = NULL, ynm = NULL, ci = TRUE,
                              legend = TRUE, ylab = TRUE,
                              x1size = 3.8, x2size = 3.8, 
                              constrained_color = "cornflowerblue",
                              unconstrained_color = "#A3C99A", constrained_shape = 16, 
                              unconstrained_shape = 18, ribbon_fill = "lightblue", line_color = "black", 
                              base_theme = ggplot2::theme_minimal(), subtitle.size = 12, angle = 0, hjust = .1)
{
  # Text sizes are used directly as ggplot2 sizes, so reject non-numbers early.
  stopifnot(is.numeric(x1size), is.numeric(x2size))
  # Element names are the keys plot.csvy() reads back from `control`.
  list(
    x1lab = x1lab,
    x1_labels = x1_labels,
    x2lab = x2lab,
    x2_labels = x2_labels,
    x3lab = x3lab, 
    x3_labels = x3_labels,
    x4_vals = x4_vals,
    x4_labels = x4_labels,
    ynm = ynm,
    ci = ci,
    legend = legend,
    ylab = ylab,
    x1size = x1size,
    x2size = x2size,
    constrained_color = constrained_color,
    unconstrained_color = unconstrained_color,
    constrained_shape = constrained_shape,
    unconstrained_shape = unconstrained_shape,
    ribbon_fill = ribbon_fill, 
    line_color = line_color,
    base_theme = base_theme,
    subtitle.size = subtitle.size,
    angle = angle,
    hjust = hjust
  )
}

#' Plot method for csvy objects
#'
#' Generates a diagnostic or summary plot from a fitted \code{"csvy"} object.
#' Supports both single-factor and two-factor visualization.
#' Aesthetic settings can be customized using \code{plot_csvy_control()}.
#'
#' With one predictor, fitted values are drawn against that predictor. Confidence
#' bands are added when \code{control$ci} is \code{TRUE} and intervals are
#' available. With two or more predictors, domains are laid out along the x-axis
#' in blocks defined by \code{x2}, separated by dashed lines. A third predictor is
#' described in the subtitle. Any further predictors are held at
#' \code{control$x4_vals` (default: the most frequent value of each in the
#' domain grid). When some domain has a missing estimate or \code{nd == 0}, those
#' domains (and any with missing interval bounds) are marked with red crosses
#' below the curves.
#'
#' @param x An object of class \code{"csvy"}.
#' @param x1 Optional. Name of the first factor to display in two-factor plots,
#'   quoted (\code{"age"}) or unquoted (\code{age}). Defaults to the first added variable.
#' @param x2 Optional. Name of the second factor to display in two-factor plots,
#'   quoted or unquoted. Defaults to the second added variable.
#' @param domains Optional. A data frame containing some domain(s) to be emphasized on
#'   the plot, with columns named after predictors. Matching domains are highlighted
#'   with red point ranges. Rows that match no modeled domain are dropped with a
#'   warning. Defaults to \code{NULL}.
#' @param type Character string, either \code{"constrained"}, \code{"unconstrained"},
#'   or \code{"both"}. Defaults to \code{"constrained"}.
#' @param control A list of display options returned by \code{\link{plot_csvy_control}}.
#'   Defaults to \code{plot_csvy_control()}. Missing entries are filled from the defaults.
#' @param ... Additional arguments passed to \code{ggplot2::geom_line()} and
#'   \code{ggplot2::geom_point()}, such as \code{linewidth} or \code{size}. They
#'   override the built-in defaults (\code{linewidth = 0.8}; point \code{size} 2 with
#'   one predictor, 1.5 otherwise). An argument named \code{x3} is not forwarded:
#'   with more than three predictors it names the predictor described in the subtitle.
#' @return A \code{ggplot2} object.
#' @seealso \code{\link{plot_csvy_control}} for a full list of customizable settings.
#' @examples
#' # plot.csvy(fit)
#' # plot.csvy(fit, x1 = "education", x2 = "region", control = plot_csvy_control(x1lab = "Education"))
#' @export
plot.csvy <- function(x, x1 = NULL, x2 = NULL, domains = NULL,
                      type = c("constrained", "unconstrained", "both"),
                      control = plot_csvy_control(), ...) {
  object <- x
  # Validate before reading any component of `object`.
  if (!inherits(object, "csvy")) {
    stop("Invalid object passed. This function requires a 'csurvey' object.")
  }
  type <- match.arg(type)

  # Fill in settings the caller left out, so every control$* used below exists.
  control <- modifyList(plot_csvy_control(), control)
  # NOTE(review): control$base_theme, control$ribbon_fill and control$line_color
  # are not applied below; theme_minimal(base_size = 13) and the fit-type colors
  # are used instead.

  extras <- list(...)
  # `x3` is a plotting option, not a geom parameter, so keep it out of the geoms.
  extra_names <- names(extras)
  geom_extras <- if (is.null(extra_names)) extras else extras[extra_names != "x3"]

  ynm <- control$ynm %||% object$ynm
  ylab_text <- if (object$family$family %in% c("binomial", "quasibinomial")) {
    "Predicted probability of"
  } else {
    "average response"
  }
  y_title <- if (isTRUE(control$ylab)) {
    if (!is.null(ynm)) paste(ylab_text, ynm) else ylab_text
  } else {
    NULL
  }

  if (length(object$Ds) == 1) {
    .csvy_plot_single(object, type, control, y_title, geom_extras, domains)
  } else {
    # Resolve names while x1/x2 are still unevaluated promises, so an unquoted
    # variable name (x1 = education) is deparsed rather than evaluated.
    x1nm <- .csvy_arg_name(x1, substitute(x1)) %||% object$xnms_add[1]
    x2nm <- .csvy_arg_name(x2, substitute(x2)) %||% object$xnms_add[2]
    .csvy_plot_pair(object, x1nm, x2nm, type, control, ynm, y_title,
                    extras, geom_extras, domains)
  }
}

# Single-predictor plot: fitted curve(s) against the lone predictor.
.csvy_plot_single <- function(object, type, control, y_title, geom_extras, domains) {
  linkinv <- object$family$linkinv
  xnm <- object$xnms_add[1]
  x_raw <- object$xm.red[[1]]
  if (is.character(x_raw)) {
    # Keep first-appearance order for character predictors.
    x_raw <- factor(x_raw, levels = unique(x_raw))
  }
  # x_pos stays in row order so it lines up with etahat/lwr/upp and with the
  # grid rows that `domains` are matched against; geoms sort by x themselves.
  if (is.factor(x_raw)) {
    # Level i is drawn at position i, so ticks follow level order.
    x_pos <- as.numeric(x_raw)
    brks <- seq_along(levels(x_raw))
    x_labels <- levels(x_raw)
  } else {
    x_pos <- x_raw
    brks <- sort(x_raw)
    x_labels <- brks
    if (isTRUE(object$replaced[[xnm]])) {
      # When flagged as replaced, label ticks with the stored original levels.
      x_labels <- object$levels_old[[1]]
    }
  }

  fit_frame <- function(fit_type, eta, lwr, upp) {
    tibble(
      x = x_pos,
      fit_type = fit_type,
      muhat = linkinv(eta),
      nd = object$nd,
      lwr = if (!is.null(lwr)) linkinv(lwr) else NA_real_,
      upp = if (!is.null(upp)) linkinv(upp) else NA_real_
    )
  }
  fits <- list()
  if (type %in% c("constrained", "both")) {
    fits$constrained <- fit_frame("constrained", object$etahat, object$lwr, object$upp)
  }
  if (type %in% c("unconstrained", "both")) {
    fits$unconstrained <- fit_frame("unconstrained", object$etahatu, object$lwru, object$uppu)
  }
  df_long <- bind_rows(fits)
  fit_labels <- names(fits)

  color_map <- c("constrained" = control$constrained_color,
                 "unconstrained" = control$unconstrained_color)
  shape_map <- c("constrained" = control$constrained_shape,
                 "unconstrained" = control$unconstrained_shape)

  p <- ggplot(df_long, aes(x = x, y = muhat, color = fit_type, shape = fit_type))
  if (isTRUE(control$ci) && !all(is.na(df_long$lwr))) {
    p <- p + geom_ribbon(aes(ymin = lwr, ymax = upp, fill = fit_type), alpha = 0.2, color = NA)
  }
  p <- p +
    do.call(geom_line, .csvy_geom_args(list(linewidth = 0.8), geom_extras)) +
    do.call(geom_point, .csvy_geom_args(list(size = 2), geom_extras)) +
    scale_color_manual(values = color_map, labels = fit_labels, name = NULL) +
    scale_fill_manual(values = color_map, labels = fit_labels, name = NULL) +
    scale_shape_manual(values = shape_map, labels = fit_labels, name = NULL) +
    scale_x_continuous(breaks = brks, labels = x_labels, minor_breaks = NULL) +
    scale_y_continuous(expand = expansion(mult = c(0.05, 0.05))) +
    labs(x = control$x1lab %||% xnm %||% "x", y = y_title) +
    theme_minimal(base_size = 13) +
    theme(
      legend.position = if (isTRUE(control$legend)) "top" else "none",
      # NOTE(review): here x1size feeds element_text() (points), whereas the
      # multi-predictor plot passes it to annotate() (mm).
      axis.text.x = element_text(angle = control$angle, hjust = control$hjust,
                                 size = control$x1size)
    )

  p <- p + .csvy_empty_domain_layer(df_long, "x")
  if (!is.null(domains)) {
    p <- p + .csvy_domain_layer(domains, object$grid, df_long$x, df_long$muhat,
                                df_long$lwr, df_long$upp)
  }
  p
}

# Multi-predictor plot: domains along the x-axis, x1 varying within blocks of
# x2; a third predictor is described in the subtitle, further ones held fixed.
.csvy_plot_pair <- function(object, x1nm, x2nm, type, control, ynm, y_title,
                            extras, geom_extras, domains) {
  xnms_add <- object$xnms_add
  grid <- object$grid
  x1lab <- control$x1lab %||% x1nm
  x2lab <- control$x2lab %||% x2nm
  other_xnms <- setdiff(xnms_add, c(x1nm, x2nm))

  if (!(x1nm %in% names(object$Ds))) {
    stop(paste("Variable", x1nm, "not found in object$Ds"))
  }
  if (!(x2nm %in% names(object$Ds))) {
    stop(paste("Variable", x2nm, "not found in object$Ds"))
  }

  keep_id <- seq_len(NROW(grid))
  title_text <- NULL
  if (length(xnms_add) > 3) {
    if ("x3" %in% names(extras)) {
      x3 <- extras$x3
      x3nm <- if (is.character(x3)) x3 else deparse(x3)
    } else {
      x3nm <- other_xnms[1]
    }
    xnm_list <- setdiff(xnms_add, c(x1nm, x2nm, x3nm))
    # Most frequent grid value of each held-fixed predictor (ties: first).
    mode_vals <- vapply(xnm_list, function(v) {
      tab <- table(grid[[v]])
      names(tab)[which.max(tab)]
    }, character(1))

    x4_vals <- control$x4_vals %||% mode_vals
    x4_labels <- control$x4_labels
    warning_msgs <- character(0)
    for (j in seq_along(xnm_list)) {
      x4_nm <- xnm_list[j]
      allowed_vals <- object$xvs2[[x4_nm]]
      if (!(x4_vals[j] %in% allowed_vals)) {
        warning_msgs <- c(warning_msgs, sprintf("x4_vals[%d] = '%s' is not valid for '%s'. Replacing with mode '%s'.", j, x4_vals[j], x4_nm, mode_vals[j]))
        x4_vals[j] <- mode_vals[j]
      }
      # Without custom labels show the chosen value itself; custom labels are
      # looked up by the chosen value's position among the allowed values.
      x4_label <- if (is.null(x4_labels)) {
        x4_vals[j]
      } else {
        x4_labels[match(x4_vals[j], allowed_vals)]
      }
      title_text <- paste(title_text, paste(x4_nm, " = ", x4_label), sep = " ")
    }
    if (length(warning_msgs) > 0) {
      warning(paste(warning_msgs, collapse = "\n"))
    }
    # Keep only grid rows where every held-fixed predictor equals its value.
    keep_rows <- Reduce(`&`, Map(function(v, val) grid[[v]] == val, xnm_list, x4_vals))
    keep_id <- which(keep_rows)
    grid <- grid[keep_id, , drop = FALSE]
  } else if (length(xnms_add) == 3) {
    x3nm <- other_xnms
  } else {
    x3nm <- NULL
  }

  D1 <- object$Ds[x1nm]
  D2 <- object$Ds[x2nm]
  M <- NROW(grid)
  block <- if (M == (D1 * D2)) rep(1:D2, each = D1) else rep(1:(D1 * D2), each = M / (D1 * D2))

  fit_types <- c(
    if (type %in% c("constrained", "both")) "constrained",
    if (type %in% c("unconstrained", "both")) "unconstrained"
  )
  plot_df <- bind_rows(lapply(fit_types, function(fit_type) {
    out <- reorder_outputs(object, x1nm, x2nm, x3nm, grid = grid, keep_id = keep_id,
                           is_constrained = fit_type == "constrained",
                           ci = control$ci, type = type)
    df <- tibble(
      domain = seq_len(M),
      fit_type = fit_type,
      muhat = out$muhat,
      lwr = out$lwr,
      upp = out$upp,
      nd = out$nd,
      block = block
    )
    df[[x1lab]] <- out$grid[[x1nm]]
    df[[x2lab]] <- out$grid[[x2nm]]
    df
  }))

  subtitle_text <- NULL
  if (M != (D1 * D2)) {
    x3_labels <- control$x3_labels
    custom_x3 <- !is.logical(x3_labels) || length(x3_labels) != 1
    if (custom_x3 || x3_labels) {
      x3_vals <- if (custom_x3) x3_labels else sort(unique(object$grid[[x3nm]]))
      subtitle_text <- paste(
        "Each line segment shows the average",
        ynm,
        "for ordered", control$x3lab %||% x3nm, "levels:", paste(x3_vals, collapse = ", ")
      )
    }
  }

  color_map <- c("constrained" = control$constrained_color,
                 "unconstrained" = control$unconstrained_color)
  block_width <- M / D2
  x2_breaks <- seq(block_width, M, by = block_width)

  p <- ggplot(plot_df, aes(x = domain, y = muhat, color = fit_type,
                           group = interaction(fit_type, block)))
  if (!all(is.na(plot_df$lwr))) {
    # Added first so the semi-transparent bands sit beneath lines and points.
    p <- p + geom_ribbon(aes(ymin = lwr, ymax = upp, fill = fit_type), alpha = 0.6, color = NA)
  }
  p <- p +
    do.call(geom_line, .csvy_geom_args(list(linewidth = 0.8), geom_extras)) +
    do.call(geom_point, .csvy_geom_args(list(size = 1.5), geom_extras)) +
    geom_vline(xintercept = x2_breaks, linetype = "dashed", color = "slategray") +
    scale_x_continuous(expand = expansion(mult = c(0.01, 0.01))) +
    scale_color_manual(values = color_map, name = NULL, na.translate = FALSE) +
    scale_fill_manual(values = color_map, name = NULL, na.translate = FALSE) +
    labs(
      x = x1lab,
      y = y_title,
      subtitle = paste(subtitle_text, title_text, sep = "\n")
    ) +
    theme_minimal(base_size = 13) +
    theme(
      legend.position = if (isTRUE(control$legend)) "top" else "none",
      plot.subtitle = element_text(size = control$subtitle.size, color = "gray30", hjust = 0.5),
      plot.margin = margin(2, 6, 2, 6)  # top, right, bottom, left (in points)
    )

  # Labels for x2 blocks, centred above each block.
  x2_positions <- x2_breaks - block_width / 2
  x2_labels <- control$x2_labels
  if (is.character(x2_labels)) {
    x2_labels <- paste(x2lab, '=', x2_labels)
  } else if (x2_labels) {
    x2_levels <- unique(object$grid[[x2nm]])
    if (isTRUE(object$replaced[[x2nm]])) {
      x2_levels <- object$levels_old[[which(colnames(object$replaced) %in% x2nm)]]
    }
    x2_labels <- paste(x2lab, '=', x2_levels)
  } else {
    x2_labels <- NULL
  }

  inv_lwr <- if (all(is.na(plot_df$lwr))) plot_df$muhat else plot_df$lwr
  inv_upp <- if (all(is.na(plot_df$upp))) plot_df$muhat else plot_df$upp
  y_range <- range(inv_lwr, inv_upp, na.rm = TRUE)

  if (!is.null(x2_labels)) {
    y_top <- max(inv_upp, na.rm = TRUE)
    # Additive offset keeps labels above the data even for negative values.
    p <- p + expand_limits(y = y_top + 0.05 * diff(y_range)) +
      annotate("text",
               x = x2_positions,
               y = y_top + 0.025 * diff(y_range),
               label = x2_labels,
               size = control$x2size)
  }

  # Labels for x1 levels, repeated in every x2 block, below the data.
  x1_labels <- control$x1_labels
  if (is.character(x1_labels)) {
    x1_labels <- rep(x1_labels, times = D2)
  } else if (x1_labels) {
    x1_levels <- sort(unique(object$xm.red[[x1nm]]))
    if (isTRUE(object$replaced[[x1nm]])) {
      x1_levels <- object$levels_old[[which(colnames(object$replaced) %in% x1nm)]]
    }
    x1_labels <- rep(x1_levels, times = D2)
  } else {
    x1_labels <- NULL
  }

  cell_width <- M / (D1 * D2)
  offset <- if (M == (D1 * D2)) 0 else 0.5 * cell_width
  x1_positions <- seq(cell_width, M, by = cell_width) - offset

  if (!is.null(x1_labels)) {
    p <- p + annotate("text",
                      x = x1_positions,
                      y = min(inv_lwr, na.rm = TRUE) - 0.05 * diff(y_range),
                      label = x1_labels,
                      angle = control$angle,
                      hjust = control$hjust,
                      size = control$x1size)
  }

  p <- p + .csvy_empty_domain_layer(plot_df, "domain")
  if (!is.null(domains)) {
    p <- p + .csvy_domain_layer(domains, object$grid, plot_df$domain, plot_df$muhat,
                                plot_df$lwr, plot_df$upp)
  }
  p
}

# Turn a quoted or unquoted argument into a variable name. `value` is the
# (lazy) argument, `expr` its substitute(); returns NULL when `value` is NULL.
.csvy_arg_name <- function(value, expr) {
  val <- tryCatch(value, error = function(e) e)
  if (is.null(val)) {
    return(NULL)
  }
  if (is.character(val)) {
    return(val)
  }
  paste(deparse(expr), collapse = " ")
}

# Merge built-in geom defaults with user arguments; user values win, avoiding
# "formal argument matched by multiple actual arguments" errors.
.csvy_geom_args <- function(defaults, user_args) {
  c(defaults[setdiff(names(defaults), names(user_args))], user_args)
}

# Red crosses just below the data for flagged domains; NULL when none.
.csvy_empty_domain_layer <- function(df, xvar) {
  if (!(anyNA(df$muhat) || any(df$nd == 0, na.rm = TRUE))) {
    return(NULL)
  }
  y_vals <- c(df$muhat, df$lwr, df$upp)
  y_min_offset <- min(y_vals, na.rm = TRUE) - 0.01 * diff(range(y_vals, na.rm = TRUE))
  flagged <- df |> filter(is.na(muhat) | is.na(lwr) | is.na(upp) | nd == 0)
  geom_point(
    data = data.frame(x = flagged[[xvar]], y = rep(y_min_offset, nrow(flagged))),
    aes(x = x, y = y), shape = 4, size = 1.5, stroke = 1, color = "red", inherit.aes = FALSE
  )
}

# Map rows of `domains` to row positions in `grid`; unmatched rows are dropped
# with a warning, and an error is raised when nothing matches.
.csvy_match_domains <- function(domains, grid) {
  if (NCOL(domains) > 1) {
    # Put the columns in the grid's column order, whatever order the user used.
    pos <- match(colnames(domains), colnames(grid))
    if (anyNA(pos)) {
      stop("Variable(s) in 'domains' not found in the model: ",
           paste(colnames(domains)[is.na(pos)], collapse = ", "))
    }
    domains <- domains[, order(pos), drop = FALSE]
  }
  ps <- match_rows(domains, grid)

  na_rows <- which(is.na(ps))
  if (all(is.na(ps))) {
    stop("None of the domains in 'domains' are defined in the model.")
  } else if (length(na_rows) > 0) {
    removed_rows <- domains[na_rows, , drop = FALSE]
    warning("The following rows were removed because their domains are not defined in the model:\n",
            paste(capture.output(print(removed_rows)), collapse = "\n"))
    ps <- ps[-na_rows]
  }
  ps
}

# Red point ranges for the requested domains at their plotted positions.
.csvy_domain_layer <- function(domains, grid, x_at, y_at, lo, hi) {
  ps <- .csvy_match_domains(domains, grid)
  new_points <- data.frame(
    x = x_at[ps],
    y = y_at[ps],
    lwr = lo[ps],
    upp = hi[ps]
  )
  geom_pointrange(
    data = new_points,
    aes(x = x, y = y, ymin = lwr, ymax = upp),
    color = "red",
    size = 0.8,
    inherit.aes = FALSE
  )
}

# Try the csurvey package in your browser
#
# Any scripts or data that you put into this service are public.
#
# csurvey documentation built on June 8, 2025, 12:41 p.m.