geom_node_plot: Draw plots at nodes

View source: R/geom_node_plot.R

geom_node_plot    R Documentation

Draw plots at nodes

Description

Additional component for a [ggparty()] plot that draws a ggplot in each selected node, using that node's data.

Usage

geom_node_plot(
  plot_call = "ggplot",
  gglist = NULL,
  width = 1,
  height = 1,
  size = 1,
  ids = "terminal",
  scales = "fixed",
  nudge_x = 0,
  nudge_y = 0,
  shared_axis_labels = FALSE,
  shared_legend = TRUE,
  predict = NULL,
  predict_gpar = NULL,
  legend_separator = FALSE
)

Arguments

plot_call

Any function that generates a 'ggplot2' object.

gglist

List of additional 'gg' components. Columns of the nodes' 'data' can be mapped. Additionally, 'fitted_values' and 'residuals' can be mapped if they are present in the 'party' object passed to 'ggparty()'.

width

Expansion factor for viewport's width.

height

Expansion factor for viewport's height.

size

Expansion factor for viewport's size.

ids

Ids of the nodes to plot. Either a numeric vector of node ids or one of "terminal", "inner" or "all". Defaults to "terminal".

scales

See the 'scales' argument of [ggplot2::facet_wrap()].

nudge_x, nudge_y

Nudge the position of the node plots in the x and y direction.

shared_axis_labels

If 'TRUE', only one pair of axis labels is plotted in the terminal space. Only recommended if 'ids' is "terminal" or "all".

shared_legend

If 'TRUE', one shared legend is plotted at the bottom of the tree.

predict

Character string specifying the variable for which predictions should be plotted.

predict_gpar

Named list containing arguments to be passed to the 'geom_line()' call of predicted values.

legend_separator

If 'TRUE', a line is drawn between the legend and the tree.
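
As an illustration of the 'predict' and 'predict_gpar' arguments described above, the following is a minimal sketch rather than an official example. It assumes a model-based tree fitted with 'lmtree()' on the built-in 'airquality' data; the object names 'airq', 'air_lm' and 'p' and the choice of model are made up for this sketch.

```r
library(ggparty)  # also attaches partykit (lmtree) and ggplot2

# Assumption for this sketch: keep complete cases only, so every node
# has usable Ozone and Temp values.
airq <- na.omit(airquality)

# Model-based tree: Ozone ~ Temp is fitted in each node, and the data
# are partitioned on Wind and Solar.R.
air_lm <- lmtree(Ozone ~ Temp | Wind + Solar.R, data = airq)

p <- ggparty(air_lm, terminal_space = 0.6) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(
    # Scatter plot of each terminal node's data
    gglist = list(geom_point(aes(x = Temp, y = Ozone))),
    # Ask for a prediction line over the regressor Temp ...
    predict = "Temp",
    # ... styled through arguments handed to the underlying geom_line()
    predict_gpar = list(col = "red")
  )

print(p)
```

Here 'predict' names the regressor of the node models, and 'predict_gpar' only changes how the prediction line looks.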

See Also

[ggparty()]

Examples


library(ggparty)

airq <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq)

ggparty(airct, horizontal = TRUE, terminal_space = 0.6) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(
    gglist = list(geom_density(aes(x = Ozone))),
    shared_axis_labels = TRUE
  )
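
As a variation on the example above, the 'ids' argument can select other nodes than the terminal ones. The sketch below reuses the same conditional inference tree and requests a density plot in every node via 'ids = "all"'. The name 'p_all' is chosen for this sketch only, and the layout may need tuning with 'width', 'height', 'nudge_x' and 'nudge_y' for a given tree.

```r
library(ggparty)  # also attaches partykit (ctree) and ggplot2

airq <- subset(airquality, !is.na(Ozone))
airct <- ctree(Ozone ~ ., data = airq)

p_all <- ggparty(airct, horizontal = TRUE, terminal_space = 0.6) +
  geom_edge() +
  geom_edge_label() +
  geom_node_plot(
    # Density of Ozone, drawn from the data that reaches each node
    gglist = list(geom_density(aes(x = Ozone))),
    # "all" selects inner and terminal nodes; "inner" or a numeric
    # vector of node ids are the other documented options
    ids = "all"
  )

print(p_all)
```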

#############################################################

## Plot with ggparty


## Demand for economics journals data
data("Journals", package = "AER")
Journals <- transform(Journals,
                      age = 2000 - foundingyear,
                      chars = charpp * pages)

## linear regression tree (OLS)
j_tree <- lmtree(log(subs) ~ log(price/citations) | price + citations +
                   age + chars + society, data = Journals, minsize = 10, verbose = TRUE)

pred_df <- get_predictions(j_tree, ids = "terminal", newdata =  function(x) {
  data.frame(
    citations = 1,
    price = exp(seq(from = min(x$`log(price/citations)`),
                    to = max(x$`log(price/citations)`),
                    length.out = 100)))
})

ggparty(j_tree, terminal_space = 0.8) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(gglist =
                   list(aes(x = `log(price/citations)`, y = `log(subs)`),
                        geom_point(),
                        geom_line(data = pred_df,
                                  aes(x = log(price/citations),
                                      y = prediction),
                                  col = "red")))

ggparty documentation built on Aug. 8, 2025, 6:45 p.m.