# R/model_diagram_function.R

# Defines functions model_diagram

# Documented in model_diagram

#' Hierarchical model diagramming
#'
#' `model_diagram()` takes a hierarchical nested model and returns a DiagrammeR
#' object visualizing the fixed and random effects structure.
#'
#' @param modelObject Input model. Either an `lme` or a `merMod` (including
#'    `glmerMod`) object with a nested random effects structure.
#' @param filePath Optional. Path to a location to export the diagram.
#'    Default is `NULL`.
#' @param fileType Optional. File type of the exported diagram.
#'    Default is `"PNG"`.
#' @param width Optional. Width of diagram in pixels. Default is `800`.
#' @param height Optional. Height of diagram in pixels. Default is `1600`.
#' @param includeSizes Optional. Include group sizes in random effect labels.
#'    Default is `TRUE`.
#' @param includeLabels Optional. Include labels for the model diagram components.
#'    Default is `TRUE`.
#' @param orientation Optional. Orientation of the diagram, either vertical or
#'    horizontal. Options are `"vertical"` and `"horizontal"`. Default is `"vertical"`.
#' @param scaleFontSize Optional. Proportional font size adjustment for the
#'    model diagram component labels and the fixed effect and random effect labels.
#'    These label font sizes are multiplied by the specified amount. Default is `1`.
#' @param shiftFixed Optional. Additive x-axis adjustment for fixed effect labels,
#'    only used when `orientation == "horizontal"`. Default is `0`.
#' @param shiftRandom Optional. Additive x-axis adjustment for random effect labels,
#'    only used when `orientation == "horizontal"`. Default is `0`.
#' @param nodeColors Optional. Outline colors for the nodes, created with
#'    [md_color()]. Components can be specified individually
#'    (`diagram`, `random`, and `fixed`).
#' @param nodeFillColors Optional. Fill colors for the nodes, created with
#'    [md_fill()]. Components can be specified individually
#'    (`diagram`, `random`, and `fixed`).
#' @param nodeFontColors Optional. Font colors for the text in the nodes,
#'    created with [md_fontColor()]. Components can be specified
#'    individually (`diagram`, `random`, and `fixed`).
#' @param exportObject Optional. Object(s) to return: `1` for a rendered DiagrammeR
#'    graph as an SVG document, `2` for a `dgr_graph` object, or `3` for
#'    a list containing both objects.
#'    Default is `1`.
#'
#' @details
#' NOTE: When this function is used in an R Markdown (Rmd) document that is
#' knitted to PDF, the graphic size arguments `width` and `height` do not
#' currently take effect. Instead, it is recommended to export the image to a
#' file (such as a PDF) and then re-import it into the document. See the
#' examples below.
#' @returns Depending on `exportObject`: a rendered DiagrammeR graph as an SVG
#' document (default), a `dgr_graph` object, or a list containing both
#' (`dgr_graph_obj`, `rendered_graph_obj`).
#' @export
#'
#' @references Greta M. Linse, Mark C. Greenwood, Ronald K. June,
#' *Data-Driven Model Structure Diagrams for Hierarchical Linear Mixed Models*,
#' J. Data Sci. (2026), 1-21, DOI: 10.6339/26-JDS1222
#'
#' @examples
#' # merMod object example
#' library(lme4)
#'
#' sleepstudy_lmer <- lmer(Reaction ~ Days + (Days | Subject), sleepstudy)
#' summary(sleepstudy_lmer)
#' model_diagram(sleepstudy_lmer)
#'
#' # lme object example
#' library(nlme)
#'
#' sleepstudy_lme <- lme(Reaction ~ Days, random=~Days|Subject, data=sleepstudy)
#' summary(sleepstudy_lme)
#' model_diagram(sleepstudy_lme)
#'
#' \donttest{
#' library(rsvg) # required to produce a PDF file
#' temp_path <- tempfile("sleepstudy_lmer_modeldiagram", fileext=".pdf")
#' # Knitting to PDF example
#' model_diagram(sleepstudy_lmer,
#'                filePath= temp_path,
#'                fileType="PDF")
#' }
model_diagram <- function(modelObject, filePath = NULL, fileType = "PNG",
                           width = 800, height = 1600, includeSizes = TRUE,
                           includeLabels = TRUE, orientation = "vertical",
                           scaleFontSize = 1,
                           shiftFixed = 0, shiftRandom = 0,
                           nodeColors = md_color(diagram="gray25", random="gray25", fixed="gray25"),
                           nodeFillColors = md_fill(diagram="aliceblue", random="aliceblue", fixed="darkseagreen1"),
                           nodeFontColors = md_fontColor(diagram="black", random="black",fixed="black"),
                           exportObject = 1){
  if(orientation == "horizontal" & width < height & width==800 & height==1600){
    # Assuming defaults were not changed for size, and change automatically
    warning("Orientation changed to horizontal and default diagram size detected, changing width and height values to match a horizontal orientation.")
    height <- 800
    width <- 1600
  } else if(orientation == "horizontal" & width < height){
    warning("Orientation changed to horizontal and width value less than height value detected, diagram may not be appropriately sized.")
  }
  if(methods::is(modelObject,"merMod")){
    lmer_formula <- deparse1(stats::formula(modelObject),
                                   collapse=" ")
    lmer_formula_fixed <- deparse1(stats::formula(modelObject,
                                                        fixed.only=TRUE),
                                         collapse=" ")
    lmer_formula_random <- substring(stringr::str_remove(lmer_formula,
                                                               stringr::fixed(lmer_formula_fixed)),
                                           4)
    lmer_formula_response <- trimws(substring(lmer_formula_fixed,
                                              first=1,
                                              last = stringr::str_locate(lmer_formula_fixed,
                                                                         stringr::fixed("~"))-1)[[1]])
    if(stringr::str_detect(lmer_formula_random, stringr::fixed("+"))){
      stop("Specification of random effects in multiple parts crossed or otherwise is not yet implemented.")
    }
    if(stringr::str_detect(lmer_formula_fixed, stringr::fixed("+ offset("))){
      lmer_formula_fixed_clean <- substring(lmer_formula_fixed,
                                            first=1,
                                            last=stringr::str_locate(lmer_formula_fixed,
                                                                     stringr::fixed("offset("))[1]-3)
    }  else if(stringr::str_detect(lmer_formula_fixed, stringr::fixed("cbind("))){
      numeratorResponse <- substring(lmer_formula_fixed,
                                     first=7,
                                     last = stringr::str_locate(lmer_formula_fixed,
                                                                stringr::fixed(","))-1)[[1]]
      lmer_formula_fixed_clean <- paste(numeratorResponse,
                                              substring(lmer_formula_fixed,
                                                              first=stringr::str_locate(lmer_formula_fixed,
                                                                                        stringr::fixed(")"))+1,
                                                              last=stringr::str_length(lmer_formula_fixed))[[1]])
    } else{
      lmer_formula_fixed_clean <- lmer_formula_fixed
    }
    if(stringr::str_count(lmer_formula_random, stringr::fixed(":"))>=1){
      lmer_formula_random_nested <- stringr::str_replace(lmer_formula_random,
                                                         stringr::fixed(":"),"/")
      lmer_formula_random_clean <- stringr::str_remove(stringr::str_remove(lmer_formula_random_nested,
                                                                           stringr::fixed("(")),
                                                       stringr::fixed(")"))
      lmer_formula_random_clean_asForm <- stats::as.formula(paste0("~",
                                                                         parse(text=lmer_formula_random_clean)))
    } else {
      lmer_formula_random_clean <- stringr::str_remove(stringr::str_remove(lmer_formula_random,
                                                                           stringr::fixed("(")),
                                                       stringr::fixed(")"))
      lmer_formula_random_clean_asForm <- stats::as.formula(paste0("~",
                                                                         parse(text=lmer_formula_random_clean)))
    }

    lmer_model_data <- stats::model.frame(modelObject)
    if(stringr::str_detect(lmer_formula_fixed, stringr::fixed("cbind("))){
      lmer_formula_response <- trimws(substring(lmer_formula_response,
                                                first=7,
                                                last = stringr::str_locate(lmer_formula_fixed,
                                                                           stringr::fixed(","))-1)[[1]])

      lmer_model_data <- cbind(as.data.frame(stats::model.frame(modelObject)[[1]])[1],
                                     stats::model.frame(modelObject)[-1])
      responseColumn <- which(names(lmer_model_data)==lmer_formula_response) # should always be 1
    } else{
      responseColumn <- which(names(lmer_model_data)==lmer_formula_response) # should always be 1
    }

    # lmer allows for a categorical/factor response variable need to assert as/convert to numeric
    lmer_model_data[responseColumn] <- as.numeric(lmer_model_data[[responseColumn]])

    tryCatch(
      lme_model <- nlme::lme(eval(parse(text=lmer_formula_fixed_clean)),
                       random=lmer_formula_random_clean_asForm,
                       data=lmer_model_data),
      error = function(e){
        message("Could not convert the merMod object to an lme object.")
      })
    if(!exists("lme_model")){
      stop("Invalid specification of random effects. Please check that the random effects are specified in hierarchical notation.")
    }
  } else if(methods::is(modelObject, "lme")){
    lme_model <- modelObject
  } else {
    stop("Please provide a linear mixed effects (lme) model with random effects obtained using the nlme package.",call. = FALSE)
  }

  # Change the label for the lowest level if a generalized linear mixed model
  # which has no random error term at the observation level
  if(methods::is(modelObject, "glmerMod")){
    obsLevel_label <- "Observation Level"
  } else{
    obsLevel_label <- "Observation Error"
  }

  ### Get information from provided model and data

  ## Random effects structure

  RElist <- names(lme_model$dims$qvec)[1:lme_model$dims$Q]

  numDelims <- max(stringr::str_count(lme_model$groups[,RElist[1]],
                                      stringr::fixed("/")))
  numRE <- lme_model$dims$Q

  theseGroups <- lme_model$groups %>%
    tidyr::separate_wider_delim(!!RElist[1],
                                cols_remove = FALSE,
                                delim = "/",
                                names = paste0("RE",1:numRE)) %>%
    dplyr::relocate(paste0("RE",1:numRE), .after = tidyselect::everything())

  if(numRE > 1){
    theseGroups_elipses <- theseGroups %>%
      dplyr::mutate(onesCol = 1)
    reIdx <- numRE
    while(reIdx > 0){
      tempREName <- paste0("RE",reIdx)
      thisRE_levels <- gtools::mixedsort(unique(theseGroups_elipses[,tempREName][[1]]))
      theseRElevels_numeric <- suppressWarnings(as.numeric(thisRE_levels))
      if(sum(as.numeric(!is.na(theseRElevels_numeric)))==length(thisRE_levels)){
        theseGroups_elipses <- theseGroups_elipses %>%
          dplyr::mutate(thisLevel = forcats::fct_relevel(get(tempREName),
                                                         gtools::mixedsort(as.character(theseRElevels_numeric))),
                        numThisLvl = as.numeric(thisLevel)) %>%
          dplyr::arrange(numThisLvl) %>%
          dplyr::select(-c(thisLevel, numThisLvl))
      } else{
        theseGroups_elipses <- theseGroups_elipses %>%
          dplyr::mutate(thisLevel = forcats::fct_relevel(get(tempREName),
                                                         gtools::mixedsort(thisRE_levels)),
                        numThisLvl = as.numeric(thisLevel)) %>%
          dplyr::arrange(numThisLvl) %>%
          dplyr::select(-c(thisLevel, numThisLvl))
      }

      reIdx <- reIdx - 1
    }
    theseGroups_elipses <- theseGroups_elipses %>%
      tibble::rowid_to_column(var="orderID") %>%
      dplyr::group_by(get(RElist[1])) %>%
      dplyr::rename(thisRE = `get(RElist[1])`) %>%
      dplyr::mutate(cumSumCol = cumsum(onesCol),
                    ObsLevel = dplyr::case_when(cumSumCol == 1 ~ "Obs. 1",
                                                cumSumCol == max(cumSumCol, na.rm = TRUE) ~
                                                  paste("Obs.", cumSumCol),
                                                TRUE ~ "...")) %>%
      dplyr::ungroup() %>%
      dplyr::select(-c(onesCol, cumSumCol, thisRE))
  } else if(numRE == 1){
    theseGroups_elipses <- theseGroups %>%
      dplyr::mutate(onesCol = 1,
                    firstLevel = forcats::fct_relevel(get(RElist[1]),
                                                      gtools::mixedsort(levels(get(RElist[1])))),
                    numFirstLvl = as.numeric(firstLevel)) %>%
      dplyr::arrange(numFirstLvl) %>%
      tibble::rowid_to_column(var = "orderID") %>%
      dplyr::select(-c(firstLevel, numFirstLvl)) %>%
      dplyr::group_by(get(RElist[1])) %>%
      dplyr::rename(thisRE = `get(RElist[1])`) %>%
      dplyr::mutate(cumSumCol = cumsum(onesCol),
                    ObsLevel = dplyr::case_when(cumSumCol==1 ~ "Obs. 1",
                                                cumSumCol==max(cumSumCol, na.rm = TRUE) ~
                                                  paste("Obs.",cumSumCol),
                                                TRUE ~ "...")) %>%
      dplyr::ungroup() %>%
      dplyr::select(-c(onesCol, cumSumCol, thisRE))
  } else{
    stop("Model does not have any random effects.")
  }


  theseGroups_names <- names(theseGroups)
  RE_gen_names <- paste0("RE",1:numRE)


  for(i in 1:numRE){
    thisColName <- paste0("RE",i)
    prevColName <- names(theseGroups_elipses)[which(names(theseGroups_elipses) == thisColName)-1]

    names(theseGroups_elipses)[which(names(theseGroups_elipses)==prevColName)] <- "prevRE"
    if(i == 1){
      thisRE_levels <- unique(theseGroups_elipses[,thisColName][[1]])
      theseRElevels_numeric <- suppressWarnings(as.numeric(thisRE_levels))
      if(sum(as.numeric(!is.na(theseRElevels_numeric)))==length(thisRE_levels)){
        maxGrpNum <- length(theseRElevels_numeric)
        firstGrpName <- utils::head(gtools::mixedsort(theseRElevels_numeric),n=1)
        lastGrpName <- utils::tail(gtools::mixedsort(theseRElevels_numeric),n=1)
      } else{
        maxGrpNum <- length(thisRE_levels)
        firstGrpName <- utils::head(gtools::mixedsort(thisRE_levels),n=1)
        lastGrpName <- utils::tail(gtools::mixedsort(thisRE_levels),n=1)
      }

      theseGroups_elipses <- theseGroups_elipses %>%
        dplyr::group_by(get(thisColName)) %>%
        dplyr::rename(thisRE = `get(thisColName)`) %>%
        dplyr::mutate(thisRE = dplyr::case_when(thisRE==firstGrpName &
                                                  prevRE != "..." ~ thisRE,
                                                thisRE==lastGrpName & prevRE != "..."  ~ thisRE,
                                                TRUE ~ "..."),
        )
    } else{
      #this RE has some character labels
      prevNestedName <- RElist[numRE-i + 2]
      fullNestedName <- RElist[numRE-i+1]
      prevNestedOrder <- as.character(unique(theseGroups_elipses[,prevNestedName][[1]]))
      group_subgroup_table <- theseGroups_elipses %>%
        dplyr::mutate(thisRE = get(thisColName),
                      onesCol = 1,
                      fPrevNestedVar = forcats::fct_relevel(get(prevNestedName),
                                                            prevNestedOrder)) %>%
        dplyr::group_by(fPrevNestedVar) %>%
        dplyr::mutate(groupID = dplyr::cur_group_id(),
                      keepThis = dplyr::case_when(prevRE != "..." ~ 1,
                                                  TRUE ~ 0))

      for(j in 1:length(unique(group_subgroup_table$groupID))){
        sub_group_subgroup_table <- group_subgroup_table %>%
          dplyr::filter(groupID==j)
        sortedThisRE <- sub_group_subgroup_table$thisRE
        minThisRE <- sortedThisRE[[1]]
        maxThisRE <- sortedThisRE[[length(sortedThisRE)]]
        for(k in 1:length(sortedThisRE)){
          thisREVal <- sortedThisRE[[k]]
          if(thisREVal != minThisRE & thisREVal != maxThisRE){
            sub_group_subgroup_table$keepThis[[k]] <- 0
          }
        }
        if(j == 1){
          new_group_subgroup_table <- sub_group_subgroup_table
        } else{
          new_group_subgroup_table <- new_group_subgroup_table %>%
            dplyr::bind_rows(sub_group_subgroup_table)
        }
      }
      suppressMessages(
        group_subgroup_table <- new_group_subgroup_table %>%
          dplyr::select(-c(onesCol, fPrevNestedVar, groupID)) %>%
          dplyr::filter(keepThis==1)
      )

      fullNestedName <- RElist[numRE-i+1]
      names(theseGroups_elipses)[which(names(theseGroups_elipses)==thisColName)] <- "thisRE"
      names(theseGroups_elipses)[which(names(theseGroups_elipses)==fullNestedName)] <- "fullNestedName"
      theseGroups_elipses <- theseGroups_elipses %>%
        dplyr::rowwise() %>%
        dplyr::mutate(keepThis = ifelse(fullNestedName %in% group_subgroup_table[RElist[numRE-i+1]][[1]],1,0 )) %>%
        dplyr::ungroup() %>%
        dplyr::mutate(thisRE = dplyr::case_when(keepThis==1 & prevRE != "..." ~ thisRE,
                                                TRUE ~ "..."),
        )

      names(theseGroups_elipses)[which(names(theseGroups_elipses)=="fullNestedName")] <- fullNestedName
    }


    names(theseGroups_elipses)[which(names(theseGroups_elipses)=="prevRE")] <- "thisREold"
    if(thisColName %in% names(theseGroups_elipses)){
      theseGroups_elipses <- theseGroups_elipses %>%
        dplyr::relocate(thisRE, .after=thisREold) %>%
        dplyr::select(-!!thisColName)
    } else{
      theseGroups_elipses <- theseGroups_elipses %>%
        dplyr::relocate(thisRE, .after=thisREold)
    }

    names(theseGroups_elipses)[which(names(theseGroups_elipses)=="thisRE")] <- thisColName
    names(theseGroups_elipses)[which(names(theseGroups_elipses)=="thisREold")] <- prevColName

  }

  for(j in 1:nrow(theseGroups_elipses)){
    obsLevelColNum <- which(names(theseGroups_elipses)=="ObsLevel")
    if(theseGroups_elipses[j,obsLevelColNum-1] == "..."){
      theseGroups_elipses[j,obsLevelColNum] <- "..."
    }
  }

  theseGroups_subset <- theseGroups_elipses %>%
    dplyr::select(-tidyselect::any_of(c("orderID", names(lme_model$groups), "keepThis"))) %>%
    unique()

  theseGroups_subset_posInfo <- theseGroups_subset %>%
    tibble::rowid_to_column(var = "ObsLevelPos") %>%
    dplyr::relocate(ObsLevelPos, .after=tidyselect::everything())

  # Observation level is first - no grouping by anything, just in the order it is
  theseGroups_subset2 <- theseGroups_subset %>%
    tibble::rowid_to_column(var="ObsLevel_rawPos") %>%
    dplyr::relocate(ObsLevel_rawPos, .after=tidyselect::everything())

  for(i in numRE:1){
    thisREname <- paste0("RE",i)
    rawIdxVal <- 1
    theseGroups_subset2$thisRE_raw <- rep(rawIdxVal, nrow(theseGroups_subset2))
    for(j in 2:nrow(theseGroups_subset2)){
      if(theseGroups_subset2[j-1,thisREname][[1]]!=theseGroups_subset2[j,thisREname][[1]]){
        rawIdxVal <- rawIdxVal + 1
      }
      theseGroups_subset2[j,"thisRE_raw"] <- rawIdxVal
    }
    names(theseGroups_subset2)[which(names(theseGroups_subset2)=="thisRE_raw")] <- paste0(thisREname, "_rawPos")
  }
  theseGroups_subset3 <- theseGroups_subset2
  rawColIdx <- rev(which(stringr::str_ends(names(theseGroups_subset2),"rawPos")))

  for(i in rawColIdx){
    thisColName <- names(theseGroups_subset3)[i]
    thisCol <- theseGroups_subset3[[i]]
    if(i==ncol(theseGroups_subset2)){
      theseGroups_subset3$newCol <- thisCol
      maxNextCol <- 0
    } else{
      maxNextCol <- maxNextCol + max(theseGroups_subset3[[i+1]])
      theseGroups_subset3$newCol <- thisCol + maxNextCol
    }

    names(theseGroups_subset3)[which(names(theseGroups_subset3)=="newCol")] <- paste0(stringr::str_remove(thisColName,"rawPos"),"edgePos")
  }

  theseGroups_subset3_clean <- theseGroups_subset3 %>%
    dplyr::select(-tidyselect::ends_with("rawPos")) %>%
    dplyr::as_tibble()

  n_levels <- c()
  for(i in 1:(numRE+1)){
    n_levels[i] <- sum(2^seq(from=0,to=i,by=1))
  }
  n_nodes <- sum(n_levels)

  numReps <- sum(2^seq(from=0, to=numRE, by=1))

  grp_lng_lme_model <- data.frame(name=as.character(),
                                  label=as.character(),
                                  x_pos=as.numeric(),
                                  y_pos=as.numeric())

  names_list <- c(paste0("RE",1:numRE),"ObsLevel")
  reNames_nested <- rep("",length(names_list))
  reNames_nested_wSize <- rep("",length(names_list))
  nested_df <- theseGroups_subset3_clean
  max_RE_width <- max(nchar(names_list))
  max_RE_width_wSize <- max(nchar(names_list))

  for(i in 1:length(names_list)){

    thisName <- names_list[i]
    nested_df$actName <- theseGroups_subset3_clean[[thisName]]

    if(i > 1){
      prevName <- paste0("Nested",i-1)
      nested_df[[thisName]] <- paste0(nested_df[[prevName]],"/",theseGroups_subset3_clean[[thisName]])

    }
    theseLevelsOrdered <- unique(nested_df[[thisName]])
    endPosVal <- length(unique(nested_df[[thisName]]))
    rowRange <- rep(as.numeric(NA), endPosVal)
    names(nested_df)[which(names(nested_df)==thisName)] <- "thisName"
    suppressMessages(
      rowRange <- nested_df %>%
        dplyr::group_by(thisName) %>%
        dplyr::summarize(rowVal = mean(dplyr::cur_group_rows())) %>%
        dplyr::arrange(rowVal) %>%
        dplyr::left_join(nested_df %>% dplyr::select(thisName, actName),
                         multiple="first")
    )

    names(nested_df)[which(names(nested_df)=="thisName")] <- paste0("Nested",i)

    grp_lng_lme_model <- grp_lng_lme_model %>%
      dplyr::bind_rows(data.frame(name=rep(thisName, endPosVal),
                                  label=rowRange$actName,
                                  x_pos=rep((i*2-1), endPosVal),
                                  y_pos=rowRange$rowVal))
    if(i == 1){
      reNames_nested[i] <- names(lme_model$groups)[1]
      reNames_nested_wSize[i] <- paste0(reNames_nested[i], "\nNum.=", lme_model$dims$ngrps[[numRE]])
      max_RE_width <- max(max_RE_width, nchar(reNames_nested[i]))
      max_RE_width_wSize <- max(max_RE_width_wSize, nchar(reNames_nested_wSize[i]))
    } else if(i == length(names_list)){
      reNames_nested[i] <- obsLevel_label
      reNames_nested_wSize[i] <- paste0(reNames_nested[i], "\nNum.=", lme_model$dims$N)
      max_RE_width <- max(max_RE_width, nchar(reNames_nested[i]))
      max_RE_width_wSize <- max(max_RE_width_wSize, nchar(reNames_nested_wSize[i]))
    } else{
      ngrpsCol <- which(names(lme_model$dims$ngrps) == names(lme_model$groups)[i])
      if(i %% 2 == 0){
        reNames_nested[i] <- paste(names(lme_model$groups)[i], "in", reNames_nested[i-1])
      } else{
        reNames_nested[i] <- paste(names(lme_model$groups)[i], "\nin", reNames_nested[i-1])
      }
      reNames_nested_wSize[i] <- paste0(reNames_nested[i],"\nNum.=", lme_model$dims$ngrps[[ngrpsCol]])
      longest_row <- max(nchar(strsplit(reNames_nested[i], "\n")[[1]]))
      longest_row_wSize <- max(nchar(strsplit(reNames_nested_wSize[i], "\n")[[1]]))
      max_RE_width <- max(max_RE_width, longest_row)
      max_RE_width_wSize <- max(max_RE_width_wSize, longest_row_wSize)
    }
  }


  if(includeSizes){
    grp_lng_lme_model <- grp_lng_lme_model %>%
      dplyr::bind_rows(data.frame(name=rep("RE Label", numRE+1),
                                  label=reNames_nested_wSize,
                                  x_pos=seq(from=1, to=numRE*2+1,by=2),
                                  y_pos=max(grp_lng_lme_model$y_pos)+2))
  } else{
    grp_lng_lme_model <- grp_lng_lme_model %>%
      dplyr::bind_rows(data.frame(name=rep("RE Label", numRE+1),
                                  label=reNames_nested,
                                  x_pos=seq(from=1, to=numRE*2+1,by=2),
                                  y_pos=max(grp_lng_lme_model$y_pos)+2))
  }


  if(methods::is(modelObject,"merMod")){
    fixedLevelInfo <- getFixLevel(lme_model,
                                  fixedCall = parse(text=lmer_formula_fixed_clean),
                                  randomCall = parse(text=lmer_formula_random_clean),
                                  obsLevelLabel = obsLevel_label)
  } else if(methods::is(modelObject, "lme")){
    fixedLevelInfo <- getFixLevel(lme_model,
                                  fixedCall = lme_model$call$fixed,
                                  randomCall = lme_model$call$random,
                                  obsLevelLabel = obsLevel_label)
  }

  namesFixed <- names(fixedLevelInfo)
  numFixed <- length(namesFixed)
  namesFixed[numFixed] <- obsLevel_label
  fixed_RElevel <- rep(numRE+1,numFixed)
  fixedPlaceholders <- rep(" ", length(names_list))

  reNamesPlusOE <- c(names(lme_model$groups), obsLevel_label)

  if(numFixed > 0){
    max_FE_width <- max(nchar(namesFixed))

    for(i in 1:(numRE+1)){
      thisREName <- reNamesPlusOE[i]
      fixedNameList <- names(fixedLevelInfo[fixedLevelInfo==thisREName])
      if(length(fixedNameList) == 0){
        fixedPlaceholders[i] <- " "
      } else{
        fixedPlaceholders[i] <- paste(fixedNameList, collapse="\n")
      }
    }

  } else{
    max_FE_width <- max_RE_width
  }

  grp_lng_lme_model <- grp_lng_lme_model %>%
    dplyr::bind_rows(data.frame(name = rep("Fixed Label", numRE+1),
                         label = fixedPlaceholders,
                         x_pos = seq(from = 1, to = numRE*2+1, by = 2),
                         y_pos = max(grp_lng_lme_model$y_pos) - 1))

  labelOffset <- 0
  topOnlyTwoLevels <- FALSE
  numTopRELevels <- 3

  if(includeLabels){
    grp_lng_lme_model$x_pos <- grp_lng_lme_model$x_pos + 1
    get_y_pos <- grp_lng_lme_model %>%
      dplyr::filter(name=="RE1" & label=="...") %>%
      dplyr::select(y_pos) %>%
      dplyr::pull()
    if(length(get_y_pos)==0){
      get_y_pos <- grp_lng_lme_model %>%
        dplyr::filter(name=="RE1") %>%
        dplyr::summarize(meanVal = mean(y_pos)) %>%
        dplyr::select(meanVal) %>%
        dplyr::pull()
      topOnlyTwoLevels <- TRUE
      numTopRELevels <- 2
    }
    grp_lng_lme_model <- grp_lng_lme_model %>%
      dplyr::bind_rows(data.frame(name=rep("Diagram Label", 3),
                           label=c("Measurement\nDiagram", "Fixed", "Random"),
                           x_pos=1,
                           y_pos=c(get_y_pos,
                                   max(grp_lng_lme_model$y_pos) - 1,
                                   max(grp_lng_lme_model$y_pos))))
    labelOffset <- 3
    x_axis_adjust <- max(max_RE_width_wSize, max_FE_width)/(min(max_RE_width_wSize, max_FE_width))
  } else{
    x_axis_adjust <- max(max_RE_width, max_FE_width)/(min(max_RE_width, max_FE_width))
  }


  md_nodes <- grp_lng_lme_model %>%
    dplyr::filter(!(name %in% c("RE Label", "Fixed Label", "Diagram Label")))

  fixed_nodes <- grp_lng_lme_model %>%
    dplyr::filter(name == "Fixed Label")

  random_nodes <- grp_lng_lme_model %>%
    dplyr::filter(name == "RE Label")

  if(includeLabels){
    md_label_node <- grp_lng_lme_model %>%
      dplyr::filter(name == "Diagram Label" & label == "Measurement\nDiagram")
    fixed_label_node <- grp_lng_lme_model %>%
      dplyr::filter(name == "Diagram Label" & label == "Fixed")
    random_label_node <- grp_lng_lme_model %>%
      dplyr::filter(name == "Diagram Label" & label == "Random")
  } else{
    md_label_node <- c()
  }

  useFontName <- "Arial"
  tooltip <- NULL

  suppressMessages(nodes_fixed <- DiagrammeR::create_node_df(
    n = nrow(fixed_nodes),
    type = "fixed",
    label = fixed_nodes$label,
    style = "filled",
    color = nodeColors$fixed,
    fillcolor = nodeFillColors$fixed,
    fontname = useFontName,
    fontcolor = nodeFontColors$fixed,
    fontsize = 10,
    shape = "box",
    DiagrammeR::node_aes(
      x = fixed_nodes$x_pos,
      y = fixed_nodes$y_pos,
      fixedsize = FALSE,
      tooltip = fixed_nodes$label,
    ),
    DiagrammeR::node_data(
      value = 1:nrow(fixed_nodes)
    )
  ))

  suppressMessages(nodes_random <- DiagrammeR::create_node_df(
    n = nrow(random_nodes),
    type = "random",
    label = random_nodes$label,
    style = "filled",
    color = nodeColors$random,
    fillcolor = nodeFillColors$random,
    fontname = useFontName,
    fontcolor = nodeFontColors$random,
    fontsize = 10,
    shape = "box",
    DiagrammeR::node_aes(
      x = random_nodes$x_pos,
      y = random_nodes$y_pos,
      fixedsize = FALSE,
      tooltip = random_nodes$label,
    ),
    DiagrammeR::node_data(
      value = 1:nrow(random_nodes)
    )
  ))
  suppressMessages(nodes_md <- DiagrammeR::create_node_df(
    n = nrow(md_nodes),
    type = "md",
    label = md_nodes$label,
    style = "filled",
    color = nodeColors$diagram,
    fillcolor = nodeFillColors$diagram,
    fontname = useFontName,
    fontcolor = nodeFontColors$diagram,
    shape = "circle",
    DiagrammeR::node_aes(
      x = md_nodes$x_pos,
      y = md_nodes$y_pos,
      fixedsize = FALSE,
      tooltip = md_nodes$label,
    ),
    DiagrammeR::node_data(
      value = 1:nrow(md_nodes)
    )
  ))

  suppressMessages(nodes_lme_model <- DiagrammeR::combine_ndfs(nodes_md, nodes_fixed,
                                                               nodes_random) %>%
                     dplyr::mutate(x = x * x_axis_adjust,
                                   fixedsize = FALSE))

  if(orientation=="horizontal"){
    nodes_lme_model <- nodes_lme_model %>%
      dplyr::mutate(x_old = x,
                    x = y,
                    y = -x_old)
  }
  setNames <- paste0("Set", 1:numRE)
  groupSets <- list(rep(list(), numRE))

  for(i in 1:numRE){
    namesRE <- names(theseGroups_subset3_clean)[c(i,i+1)]
    names_edgePos <- names(theseGroups_subset3_clean)[which(names(theseGroups_subset3_clean) %in% paste0(namesRE, "_edgePos"))]
    groups_subset <- theseGroups_subset3_clean[,c(namesRE, names_edgePos)] %>%
      unique() %>%
      dplyr::rename(fromLabel=!!namesRE[1],
                    toLabel=!!namesRE[2],
                    fromPos=!!names_edgePos[1],
                    toPos=!!names_edgePos[2])
    groupSets[[i]] <- groups_subset

  }

  group_edges_df <- groupSets[[1]]
  if(numRE > 1){
    for(i in 2:numRE){
      group_edges_df <- group_edges_df %>%
        dplyr::bind_rows(groupSets[[i]])
    }
  }


  edges_lme_model <-
    DiagrammeR::create_edge_df(
      from = group_edges_df$fromPos,
      to = group_edges_df$toPos,
      rel = "requires",
      color = "black")
  graph_lme_model <- DiagrammeR::create_graph(nodes_df=nodes_lme_model,
                                              edges_df=edges_lme_model)


  if(includeLabels){
    suppressMessages(graph_lme_model <- graph_lme_model %>%
                     DiagrammeR::add_n_nodes(
                       n = 1,
                       type = "md_label",
                       label = md_label_node$label,
                       DiagrammeR::node_aes(
                         style = "filled",
                         color = nodeColors$diagram,
                         fillcolor = nodeFillColors$diagram,
                         fontname = useFontName,
                         fontcolor = nodeFontColors$diagram,
                         fontsize = 10,
                         shape = "box",
                         x = md_label_node$x_pos,
                         y = md_label_node$y_pos,
                         fixedsize = FALSE,
                         tooltip = md_label_node$label,
                       ))  %>%
                       DiagrammeR::add_n_nodes(
                         n = 1,
                         type = "fixed_label",
                         label = fixed_label_node$label,
                         DiagrammeR::node_aes(
                           style = "filled",
                           color = nodeColors$fixed,
                           fillcolor = nodeFillColors$fixed,
                           fontname = useFontName,
                           fontcolor = nodeFontColors$fixed,
                           fontsize = 10,
                           shape = "box",
                           x = fixed_label_node$x_pos,
                           y = fixed_label_node$y_pos,
                           fixedsize = FALSE,
                           tooltip = fixed_label_node$label,
                         ))  %>%
                       DiagrammeR::add_n_nodes(
                         n = 1,
                         type = "random_label",
                         label = random_label_node$label,
                         DiagrammeR::node_aes(
                           style = "filled",
                           color = nodeColors$random,
                           fillcolor = nodeFillColors$random,
                           fontname = useFontName,
                           fontcolor = nodeFontColors$random,
                           fontsize = 10,
                           shape = "box",
                           x = random_label_node$x_pos,
                           y = random_label_node$y_pos,
                           fixedsize = FALSE,
                           tooltip = random_label_node$label,
                         ))
    )
  }
  if(orientation=="horizontal"){
    graph_lme_model <- graph_lme_model %>%
      DiagrammeR::select_nodes() %>%
      DiagrammeR::mutate_node_attrs_ws(fixedsize = FALSE,
                                       fontsize = fontsize * scaleFontSize,
                                       tooltip = ifelse(is.na(tooltip)," ",tooltip)) %>%
      DiagrammeR::clear_selection() %>%
      DiagrammeR::select_nodes(
        conditions = type %in% c("md_label","fixed_label", "random_label")) %>%
      DiagrammeR::mutate_node_attrs_ws(
        x_old = x,
        x = y,
        y = -x_old
      ) %>%
      DiagrammeR::clear_selection() %>%
      DiagrammeR::select_nodes(
        conditions = type %in% c("fixed", "fixed_label")) %>%
      DiagrammeR::nudge_node_positions_ws(
        dx = shiftFixed, dy = 0) %>%
      DiagrammeR::clear_selection() %>%
      DiagrammeR::select_nodes(
        conditions = type %in% c("random", "random_label")) %>%
      DiagrammeR::nudge_node_positions_ws(
        dx = shiftRandom, dy = 0)
  } else{
    graph_lme_model <- graph_lme_model %>%
      DiagrammeR::select_nodes() %>%
      DiagrammeR::mutate_node_attrs_ws(fixedsize = FALSE,
                                       fontsize = fontsize * scaleFontSize,
                                       tooltip = ifelse(is.na(tooltip)," ",tooltip))
  }

  if(!is.null(filePath)){
    suppressWarnings(
      graph_lme_model %>%
        DiagrammeR::export_graph(file_name = filePath,
                                 file_type = fileType,
                                 width=width,height=height)
    )
  }
  if(exportObject == 1){
    graph_lme_model %>%
      DiagrammeR::render_graph(width=width,height=height,
                               output="graph", as_svg=TRUE)
  } else if(exportObject == 2){
    return(graph_lme_model)
  } else{

    rg_svg <- graph_lme_model %>%
      DiagrammeR::render_graph(width=width,height=height,
                               output="graph", as_svg=TRUE)

    return(list(dgr_graph_obj = graph_lme_model, rendered_graph_obj = rg_svg))

  }


}

# https://stackoverflow.com/questions/24129124/how-to-determine-if-a-character-vector-is-a-valid-numeric-or-integer-vector
# Post by Adriano Rivolli on Sep 21, 2017 at 16:30 accessed on Dec 10, 2024 at 4:30 pm

#' Catch numeric values stored as strings
#'
#' `catchNumeric()` takes a character vector and converts every element that
#' parses as a (non-missing) number to numeric, leaving the other elements as
#' character strings.
#' Adapted from: https://stackoverflow.com/questions/24129124/how-to-determine-if-a-character-vector-is-a-valid-numeric-or-integer-vector
#' Post by Adriano Rivolli on Sep 21, 2017 at 16:30 accessed on Dec 10, 2024 at 4:30 pm
#'
#' @keywords internal
#' @param strList A character vector. Factors are converted to character first,
#'    so their labels (not their internal integer codes) are parsed.
#'
#' @returns A list the same length as `strList`. Each element is either a
#'    length-one numeric (when the string parses as a number) or the original
#'    string.
#' @export
#'
#' @examples
#' charList <- c("A","B","C")
#' catchNumeric(charList)
#'
#' numList <- c("1","2","3")
#' catchNumeric(numList)
#'
#' mixList <- c("A","2","C","4")
#' catchNumeric(mixList)
catchNumeric <- function(strList) {
  # as.numeric() on a factor returns the level codes, which would silently
  # replace every label with an integer; parse the labels instead.
  if (is.factor(strList)) {
    strList <- as.character(strList)
  }
  # Strings that do not parse become NA; the coercion warning is expected.
  numList <- suppressWarnings(as.numeric(strList))
  isNum <- !is.na(numList)
  fixedList <- as.list(strList)
  fixedList[isNum] <- numList[isNum]
  fixedList
}

#' Identify the fixed effect model structure level
#'
#' `getFixLevel()` uses the degrees of freedom information from `lme()`
#' to identify the correct placement of each fixed effect and interaction term
#' in a model structure diagram for hierarchical mixed effects models.
#' The logic follows `getFixDF()` from nlme's lme.R: each non-constant
#' fixed-effect term is assigned to the last stratum (grouping level, or the
#' observation level) in which at least one of its model-matrix columns varies
#' within groups.
#'
#' @keywords internal
#' @param lme_model A `nlme` lme model object inherited from `model_diagram()`.
#'    Its `data` and `dims$ngrps` components are used.
#' @param fixedCall The fixedCall object from an lme model or parsed text
#'    containing fixed variable names from a merMod object
#' @param randomCall The randomCall object from an lme model or parsed text
#'    containing the random effect structure from a merMod object
#' @param obsLevelLabel The label that should be used for the observational level
#'
#' @returns A character vector, named by fixed-effect term, giving the stratum
#'    at which each non-constant term is placed for degrees of freedom
#'    calculations (a grouping-level name or `obsLevelLabel`). Constant terms
#'    such as the intercept are omitted. Returns `NULL` when no model-matrix
#'    column varies. Errors if the fixed formula is not two-sided, if
#'    `randomCall` is neither a call nor an expression, if the fixed formula
#'    contains `offset()` terms, or if the data contain missing values.
#' @export
#'
#' @examples
#' library(nlme)
#' library(lme4) # For sleepstudy data
#' sleepstudy_lme <- lme(Reaction ~ Days, random=~Days|Subject, data=sleepstudy)
#' getFixLevel(sleepstudy_lme, fixedCall = sleepstudy_lme$call$fixed,
#'             randomCall = sleepstudy_lme$call$random,
#'             obsLevelLabel = "Observation Error")
getFixLevel <- function(lme_model, fixedCall, randomCall, obsLevelLabel){

  data <- lme_model$data

  ## Fixed-effects formula, evaluated from the stored call / parsed text
  fixed <- eval(fixedCall)
  if (!inherits(fixed, "formula") || length(fixed) != 3) {
    stop("fixed-effects model must be a formula of the form \"resp ~ pred\"")
  }

  ## Random-effects formula: an unevaluated call (lme) or an expression (merMod)
  if (methods::is(randomCall, "call")) {
    random <- eval(randomCall)
  } else if (methods::is(randomCall, "expression")) {
    random <- eval(parse(text = paste0("~", randomCall)))
  } else {
    stop(paste("randomCall has a class that is",
               paste(class(randomCall), collapse = ", ")))
  }

  random <- nlme::reStruct(stats::formula(random), data = NULL)
  reSt <- nlme::reStruct(random, REML = TRUE, data = NULL)
  groups <- nlme::getGroupsFormula(reSt)

  ## lme structure without correlation/variance plug-ins; only used below to
  ## collect the variables needed by the random-effects part
  lmeSt <- nlme::lmeStruct(reStruct = reSt, corStruct = NULL,
                           varStruct = nlme::varFunc(NULL))

  ## Model frame with every variable used by the fixed effects, the random
  ## effects and the grouping factors (missing values are an error)
  mfArgs <- list(formula = nlme::asOneFormula(stats::formula(lmeSt), fixed, groups),
                 data = data, na.action = stats::na.fail,
                 drop.unused.levels = TRUE)
  dataMix <- do.call(stats::model.frame, mfArgs)
  origOrder <- row.names(dataMix)

  ## Grouping factors, with rows sorted by group: inner_perc_table_R()
  ## assumes rows of the same group are contiguous
  grps <- nlme::getGroups(dataMix, groups)
  if (inherits(grps, "factor")) { # single level
    ord <- order(grps)
    grps <- data.frame(grps)
    row.names(grps) <- origOrder
    names(grps) <- as.character(deparse(groups[[2L]]))
  } else {
    ord <- do.call(order, grps)
    ## make nested levels unique by prefixing each with its (already
    ## prefixed) outer level; seq_len() keeps this safe for one column
    for (i in seq_len(ncol(grps))[-1L]) {
      grps[, i] <- as.factor(paste(as.character(grps[, i - 1]),
                                   as.character(grps[, i]), sep = "/"))
    }
  }
  grps <- grps[ord, , drop = FALSE]
  dataMix <- dataMix[ord, , drop = FALSE]

  ## Fixed-effects model matrix, in the same group-sorted row order
  X <- stats::model.frame(fixed, dataMix)
  Terms <- attr(X, "terms")
  if (length(attr(Terms, "offset"))) {
    stop("offset() terms are not supported")
  }
  X <- stats::model.matrix(fixed, data = X)

  # Below is adapted from getFixDF() from nlme's lme.R script
  ## Map each term to the model-matrix columns it generates
  assign <- attr(X, "assign")
  if (!is.list(assign)) {
    namTerms <- attr(Terms, "term.labels")
    if (attr(Terms, "intercept") > 0) {
      namTerms <- c("(Intercept)", namTerms)
    }
    namTerms <- factor(assign, labels = namTerms)
    assign <- split(order(assign), namTerms)
  }
  ## function to check if a vector is (nearly) a multiple of (1,1,...,1)
  const <- function(x, tolerance = sqrt(.Machine$double.eps)) {
    if (length(x) < 1) return(NA)
    x <- as.numeric(x)
    all(abs(if (x[1L] == 0) x else x/x[1L] - 1) < tolerance)
  }
  p <- ncol(X)
  Qp1 <- ncol(grps) + 1L
  namX <- colnames(X)
  ## mirrors getFixDF(): reverse dims$ngrps and drop its first two entries
  ## (presumably the non-grouping X/y placeholders -- confirm against nlme)
  ngrps <- rev(lme_model$dims$ngrps)[-(1:2)]
  stratNam <- c(names(ngrps), obsLevelLabel)
  namTerms <- names(assign)

  ## Columns of X that are not constant
  notIntX <- !apply(X, 2, const)
  if (!any(notIntX)) {
    return(NULL)
  }

  ## Fraction of groups within which each column of X varies ("inner");
  ## rep(1, p) fills the first stratum column
  innP <- array(c(rep(1, p), inner_perc_table_R(X, grps)),
                dim = c(p, Qp1),
                dimnames = list(namX, stratNam))
  ## Stratum in which each column of X is estimated: the last stratum with a
  ## non-zero inner fraction (fractional percentages are ignored, as in nlme)
  stratX <- stratNam[apply(innP, 1, function(el, index) max(index[el > 0]),
                           index = 1:Qp1)]
  ## Terms with at least one non-constant column
  notIntTerms <- vapply(assign, function(el) any(notIntX[el]), logical(1))
  ## Each term sits at the last stratum of any of its columns
  stratTerms <- stratNam[vapply(assign,
                                function(el) max(match(stratX[el], stratNam)),
                                integer(1))][notIntTerms]
  ## Name by the surviving terms themselves; dropping only the first name
  ## mislabels models without an intercept or with other constant terms
  names(stratTerms) <- namTerms[notIntTerms]
  stratTerms
}


#
# Code translation assistance provided by Anthropic's Claude 3.5 Sonnet (2024 version) on December 17, 2024.
# translated inner_perc() and inner_perc_table() from nlme's nlmefit.c file (https://github.com/cran/nlme/blob/master/src/nlmefit.c)
#

# # static double
# inner_perc(double *x, int *grp, int n)
# /* percentage of groups for which x is inner */
#   {
#     /* x - column of X matrix to be assessed
#     grp - integer vector with groups
#     n - length of x and grp
#     data are assumed to be ordered by grp */
#
#       int currGrp, nn = 0, isInner;
#     double nInner = 0., nGrp = 0., currVal;
#
#     while (nn < n) {
#       currGrp = grp[nn];
#       currVal = x[nn];
#       nGrp++;
#       isInner = 0;
#       do {
#         if (isInner == 0 && x[nn] != currVal) {
#           nInner++;
#           isInner = 1;
#         }
#         nn++;
#       } while (nn < n && currGrp == grp[nn]);
#     }
#     return(nInner/nGrp);
#   }

#' Calculates inner group percentage
#'
#' `inner_perc_R()` Calculates the proportion of groups for which the fixed
#'  effect `x` is "inner", i.e. takes more than one value within the group.
#'  Data are assumed to be ordered by `grp`: as in the C original, a group is a
#'  run of consecutive identical labels.
#'
#'  Translated `inner_perc()` from nlme's nlmefit.c file (https://github.com/cran/nlme/blob/master/src/nlmefit.c)
#'
#'  Code translation assistance provided by Anthropic's Claude 3.5 Sonnet (2024 version) on December 17, 2024.
#'
#' @keywords internal
#' @param x Column of X matrix to be assessed. Must not contain missing values.
#' @param grp Vector of group labels (e.g. integer, character or factor), the
#'    same length as `x`. Must not contain missing values.
#'
#' @returns A single number between 0 and 1: the proportion of groups within
#'    which `x` is not constant. `NaN` when `x` is empty.
#' @export
#'
#' @examples
#' library(lme4)
#' data(sleepstudy)
#' # When used in model_diagram() group labels ("grps") should be unique
#' inner_perc_R(sleepstudy[,"Days"],sleepstudy[,"Subject"])
inner_perc_R <- function(x, grp){

  if (length(x) != length(grp)) {
    stop("x and grp must have the same length")
  }
  # The scalar loop this replaces crashed unpredictably on NA comparisons
  if (anyNA(x) || anyNA(grp)) {
    stop("x and grp must not contain missing values")
  }

  n <- length(x)
  if (n == 0L) {
    # No groups at all: same 0/0 result as the C original
    return(NaN)
  }

  # A new group starts wherever the label differs from the previous row
  runStart <- c(TRUE, grp[-1L] != grp[-n])
  runId <- cumsum(runStart)

  # A group is inner if any of its values differs from its first value
  firstVal <- x[runStart][runId]
  innerRuns <- unique(runId[x != firstVal])

  length(innerRuns) / runId[n]
}


#' Table of inner group percentages for each fixed effect variable
#'
#' `inner_perc_table_R()` Calculates the inner group percentage for each fixed
#'  effect variable.
#'
#'  Translated `inner_perc_table()` from nlme's nlmefit.c file (https://github.com/cran/nlme/blob/master/src/nlmefit.c)
#'
#'  Code translation assistance provided by Anthropic's Claude 3.5 Sonnet (2024 version) on December 17, 2024.
#'
#' @keywords internal
#' @param X Matrix or data frame from lme object. A plain vector is treated as
#'        a one-column matrix.
#' @param grps Matrix or data frame where each column is a grouping variable.
#'        A plain vector is treated as a one-column matrix.
#' @param p If `NULL` the number of columns of `X`, else a pre-specified number
#'        of fixed effects.
#' @param Q If `NULL` the number of columns in `grps`, else a pre-specified number
#'        of random effects or groups.
#'
#' @returns A `p` x `Q` matrix of inner group proportions (one row per column
#'        of `X`, one column per grouping variable), with row and column names
#'        taken from `X` and `grps` when available.
#' @export
#'
#' @examples
#' library(lme4)
#' data(sleepstudy)
#' # When used in model_diagram() group labels ("grps") should be unique
#' inner_perc_table_R(sleepstudy[,c("Reaction","Days")],
#'                    as.data.frame(sleepstudy[,"Subject"]))
inner_perc_table_R <- function(X, grps, p = NULL, Q = NULL){

  # Coerce plain vectors to one-column matrices *before* deriving the default
  # p and Q; ncol() of a bare vector is NULL and broke the checks below.
  if (!is.matrix(X) && !is.data.frame(X)) {
    X <- as.matrix(X)
  }
  if (!is.matrix(grps) && !is.data.frame(grps)) {
    grps <- as.matrix(grps)
  }
  if (is.null(p)) {
    p <- ncol(X)
  }
  if (is.null(Q)) {
    Q <- ncol(grps)
  }

  # Validate dimensions
  if (nrow(X) != nrow(grps)) {
    stop("X and grps must have the same number of rows")
  }
  if (ncol(grps) != Q) {
    stop("Number of grouping variables doesn't match Q")
  }
  if (ncol(X) != p) {
    stop("Number of columns in X doesn't match p")
  }

  # One row per column of X, one column per grouping variable
  pTable <- matrix(0, nrow = p, ncol = Q)

  # seq_len() keeps zero-column inputs from iterating over c(1, 0);
  # drop = TRUE extracts a plain vector from matrices, data frames and tibbles
  for (i in seq_len(Q)) {
    for (j in seq_len(p)) {
      pTable[j, i] <- inner_perc_R(X[, j, drop = TRUE], grps[, i, drop = TRUE])
    }
  }

  # Add row and column names if they exist
  if (!is.null(colnames(X))) {
    rownames(pTable) <- colnames(X)
  }
  if (!is.null(colnames(grps))) {
    colnames(pTable) <- colnames(grps)
  }

  pTable
}

#' Diagram elements
#'
#' @description
#' Similar to the [ggplot2::theme()] system used with [ggplot2::ggplot()], the
#' `md_` functions specify how the node components of the diagram are drawn.
#'
#'   - `md_color()`: border (outline) color of the nodes.
#'   - `md_fill()`: fill color of the nodes.
#'   - `md_fontColor()`: font color of the text in the nodes.
#'
#' Each function returns a list with the elements `diagram`, `random`, and
#' `fixed`, to be passed to the `nodeColors`, `nodeFillColors`, and
#' `nodeFontColors` arguments of [model_diagram()], respectively.
#'
#' @param diagram Outline color, fill color, or font color (depending on the
#'    function) for the elements of the hierarchical model diagram itself.
#'    Default is `"gray25"` for ([md_color()]), `"aliceblue"` for ([md_fill()]),
#'    and `"black"` for ([md_fontColor()]).
#' @param random Outline color, fill color, or font color (depending on the
#'    function) for the random effect variable boxes.
#'    Default is `"gray25"` for ([md_color()]), `"aliceblue"` for ([md_fill()]),
#'    and `"black"` for ([md_fontColor()]).
#' @param fixed Outline color, fill color, or font color (depending on the
#'    function) for the fixed effect variable boxes.
#'    Default is `"gray25"` for ([md_color()]), `"darkseagreen1"` for ([md_fill()]),
#'    and `"black"` for ([md_fontColor()]).
#'
#' @examples
#' # merMod object example
#' library(lme4)
#'
#' sleepstudy_lmer <- lmer(Reaction ~ Days + (Days | Subject), sleepstudy)
#' summary(sleepstudy_lmer)
#' model_diagram(sleepstudy_lmer,
#'   nodeColors = md_color(diagram="steelblue", random="steelblue", fixed="darkolivegreen1"),
#'   nodeFillColors = md_fill(diagram="aliceblue", random="aliceblue", fixed="darkseagreen1"),
#'   nodeFontColors = md_fontColor(diagram="black", random="blue", fixed="black"))
#'
#' @name md
#' @aliases NULL
NULL

#' @returns A list of class `md_color` with elements `diagram`, `random`, and
#'    `fixed` giving the border (outline) colors of the nodes.
#' @export
#' @rdname md
md_color <- function(diagram="gray25", random="gray25", fixed="gray25") {
  # Outline colors keyed by diagram component
  spec <- list(diagram = diagram, random = random, fixed = fixed)
  class(spec) <- c("md_color", "md")
  spec
}

#' @returns A list of class `md_fill` with elements `diagram`, `random`, and
#'    `fixed` giving the fill colors of the nodes. The object also inherits
#'    from `md_color` for backward compatibility.
#' @export
#' @rdname md
md_fill <- function(diagram="aliceblue", random="aliceblue", fixed="darkseagreen1") {

  # "md_fill" comes first so fill specifications can be told apart from
  # outline ones; "md_color" is kept so existing inherits() checks still pass.
  structure(
    list(diagram = diagram, random = random, fixed = fixed),
    class = c("md_fill", "md_color", "md")
  )
}

#' @returns A list of class `md_fontColor` with elements `diagram`, `random`,
#'    and `fixed` giving the font colors of the node text. The object also
#'    inherits from `md_color` for backward compatibility.
#' @export
#' @rdname md
md_fontColor <- function(diagram="black", random="black", fixed="black") {

  # "md_fontColor" comes first so font specifications can be told apart from
  # outline ones; "md_color" is kept so existing inherits() checks still pass.
  structure(
    list(diagram = diagram, random = random, fixed = fixed),
    class = c("md_fontColor", "md_color", "md")
  )
}

Try the modeldiagramR package in your browser

Any scripts or data that you put into this service are public.

modeldiagramR documentation built on April 15, 2026, 5:07 p.m.