R/EM_plot.R

Defines functions plot.EM

Documented in plot.EM

#' Diagnostic plots for EM fits
#'
#' Produce diagnostic plots of EM fits returned from [mixfit()].
#'
#' @param x EM fit
#' @param size Optional argument passed to the `ggplot2` routines;
#' it controls the line thickness.
#' @param link Choice of an applied link function. Can take one of the
#' values `identity` (default), `logit` or `log`.
#' @param ... Ignored.
#'
#' @details Overlays the fitted mixture density with a histogram and
#' a density plot of the raw sample fitted. Applying a link function
#' can be beneficial, for example a `logit` (`log`) link for beta
#' (gamma) mixtures obtained from a Binomial (Poisson)
#' [gMAP()] analysis.
#'
#' @template plot-help
#'
#' @return A list of [ggplot2::ggplot()] plots for
#' diagnostics of the EM run. Detailed EM diagnostic plots are
#' included only if the global option `RBesT.verbose` is set to
#' `TRUE`. These include plots of the parameters of each
#' component vs the iteration. The plot of the mixture density with a
#' histogram and a density of the fitted sample is always returned.
#'
#' @family EM
#'
#' @examples
#'
#' bmix <- mixbeta(rob = c(0.2, 1, 1), inf = c(0.8, 10, 2))
#' bsamp <- rmix(bmix, 1000)
#' bfit <- mixfit(bsamp, type = "beta", Nc = 2)
#' pl <- plot(bfit)
#'
#' print(pl$mixdens)
#' print(pl$mix)
#'
#' \donttest{
#' # a number of additional plots are generated in verbose mode
#' .user_option <- options(RBesT.verbose = TRUE)
#' pl_all <- plot(bfit)
#'
#' # recover previous user options
#' options(.user_option)
#'
#' names(pl_all)
#' # [1] "mixdist" "a"   "b"   "w"   "m"   "N"   "Lm"  "lN"  "Lw"  "lli" "mixdens" "mixecdf" "mix"
#' }
#'
#' @method plot EM
#' @export
plot.EM <- function(x, size = 1.25, link = c("identity", "logit", "log"), ...) {
  ## Builds a named list of ggplot objects for an EM fit `x`:
  ##  * `mixdist` (univariate fits only): result of plot.mix()
  ##  * per-parameter trace plots and `lli` (only if option
  ##    RBesT.verbose is TRUE)
  ##  * `mixdens`, `mixecdf`, `mix` (univariate fits) or `mixpairs`
  ##    (multivariate normal fits): fitted mixture vs. the fitted sample
  ## `size` is handed to the ggplot2 layers as line size, `link` selects
  ## the link applied before plotting; `...` is forwarded to plot.mix().
  pl <- list()
  if (!is_mixmv(x)) {
    pl$mixdist <- plot.mix(x, size = size, ...)
  }
  ## in verbose mode we output EM fit diagnostics
  if (getOption("RBesT.verbose", FALSE)) {
    ## these NULL assignments make R check happy
    a <- b <- w <- s <- comp <- iteration <- NULL
    Nc <- ncol(x)
    ## one data frame per stored EM iterate: a row per component, a
    ## column per parameter, plus the component index
    pseq <- lapply(
      attr(x, "traceMix"),
      function(m) {
        class(m) <- "matrix"
        m <- as.data.frame(t(m))
        m$comp <- seq_len(Nc)
        m
      }
    )
    ## iterations are labelled starting from 0
    names(pseq) <- seq_along(pseq) - 1
    Mw <- dplyr::bind_rows(pseq, .id = "iteration")
    ## reorder columns to: iteration, comp, then the parameters.
    ## NOTE(review): this hard-codes exactly three parameter columns per
    ## component -- confirm this holds for every fit type reaching here.
    Mw <- Mw[c(1, 5, 2, 3, 4)]
    Mw$iteration <- as.numeric(Mw$iteration)
    ## add transformed versions of the parameters, depending on the
    ## mixture type
    if (inherits(x, "EMbmm")) {
      Mw <- within(Mw, {
        m <- a / (a + b)
        N <- a + b
        Lm <- logit(m)
        lN <- log(N)
      })
    }
    if (inherits(x, "EMnmm")) {
      Mw <- within(Mw, {
        ls <- log(s)
      })
    }
    if (inherits(x, "EMgmm")) {
      Mw <- within(Mw, {
        la <- log(a)
        lb <- log(b)
      })
    }
    ## logit of the weights is only added when there is more than one
    ## component and the fit is one of the types handled above
    if (Nc != 1 && inherits(x, c("EMbmm", "EMnmm", "EMgmm"))) {
      Mw <- within(Mw, {
        Lw <- logit(w)
      })
    }
    ## everything except iteration and comp gets its own trace plot
    pars <- names(Mw)[-c(1, 2)]
    Mw <- within(Mw, {
      Comp <- factor(comp)
    })
    LL <- data.frame(iteration = 0:max(Mw$iteration), lli = attr(x, "traceLli"))
    basePl <- ggplot(Mw, aes(x = .data$iteration, colour = .data$Comp)) +
      geom_line(size = size)
    for (p in pars) {
      ## `.data$p` would look up a column literally called "p". Index by
      ## the value of `p` instead and inject it with `!!` so the mapping
      ## is fixed now rather than resolved when the plot is rendered
      ## (by which time the loop variable holds the last parameter).
      pl[[p]] <- basePl + aes(y = .data[[!!p]])
    }
    pl$lli <- ggplot(
      subset(LL, iteration > 0),
      aes(x = .data$iteration, y = .data$lli)
    ) +
      geom_line(size = size) +
      ylab("log-likelihood")
  }
  ## pl$mix <- plot.mix(x, comp=TRUE, samp=attr(x, "x"), ...)
  link <- match.arg(link)
  dlink(x) <- link_map[[link]]

  cols <- bayesplot::color_scheme_get(i = 1:6)

  if (!is_mixmv(x)) {
    ## univariate case
    samp <- data.frame(Sample = mixlink(x, as.vector(attr(x, "x"))))
    ## workaround a weird bug in ggplot which enlarges the interval
    interval <- quantile(samp$Sample, c(0.025, 0.975))
    n_fun <- 501
    max_span <- diff(range(samp$Sample))
    interval_span <- diff(interval)
    ## scale the number of evaluation points with the ratio of the full
    ## sample range to the central 95% range, capped at 5000
    n_fun <- min(5E3, round(n_fun * max_span / interval_span))

    if (!is.dlink_identity(dlink(x))) {
      subtitle <- paste("Link:", dlink(x)$name)
    } else {
      subtitle <- NULL
    }

    pl$mixdens <- bayesplot::mcmc_dens(samp) +
      bayesplot::facet_text(FALSE) +
      stat_function(
        inherit.aes = FALSE,
        fun = dmix,
        args = list(mix = x),
        size = size,
        n = n_fun
      ) +
      ggtitle(
        "Parametric Mixture (black line) and Kernel Estimate of Sample Density",
        subtitle = subtitle
      ) +
      bayesplot::xaxis_title(FALSE)

    pl$mixecdf <- ggplot(samp, aes(x = .data$Sample)) +
      stat_ecdf(geom = "area", size = 0, fill = cols$light) +
      stat_ecdf(geom = "step", size = size, colour = cols$mid) +
      stat_function(fun = pmix, args = list(mix = x), size = size, n = n_fun) +
      ggtitle(
        "Estimated Cumulative Density from Parametric Mixture (black line) and Sample",
        subtitle = subtitle
      ) +
      bayesplot::bayesplot_theme_get() +
      bayesplot::yaxis_title(FALSE) +
      bayesplot::xaxis_title(FALSE) +
      bayesplot::facet_text(FALSE)

    pl$mix <- bayesplot::mcmc_hist(
      samp,
      binwidth = diff(interval) / 50,
      freq = FALSE
    ) +
      bayesplot::facet_text(FALSE) +
      stat_function(
        inherit.aes = FALSE,
        fun = dmix,
        args = list(mix = x),
        size = size,
        n = n_fun
      ) +
      ggtitle(
        "Parametric Mixture Density (black line) and Histogram of Sample",
        subtitle = subtitle
      ) +
      bayesplot::xaxis_title(FALSE)
  } else if (inherits(x, "mvnormMix")) {
    var1 <- var2 <- NULL
    ## multivariate case: only support mvnorm for now (the only
    ## one supported as of Aug 2023). Plot the pair-wise marginal
    ## densities, which are the pair-wise marginal mixtures.
    message(
      "Diagnostic plots for mixture multivariate normal densities are experimental.\nPlease note that these are subject to changes in future releases."
    )
    samp <- attr(x, "x")
    p <- ncol(samp)
    dim_labels <- mvnorm_dim_labels(x[-1, 1])
    if (is.null(dim_labels)) {
      dim_labels <- paste0("Dimension ", seq_len(p))
    }
    breaks <- 60
    nbins <- 20
    ## lower triangle (including the diagonal) of the p x p panel grid
    pairs <- subset(
      expand.grid(var1 = seq_len(p), var2 = seq_len(p)),
      var1 >= var2
    )
    layout <- matrix(NA, nrow = p, ncol = p)
    pl_pairs <- list()
    pl_pairs_compact <- list()
    for (i in seq_len(nrow(pairs))) {
      v1 <- pairs$var1[i]
      v2 <- pairs$var2[i]
      label <- paste0("mixpair[", v2, ",", v1, "]")
      mix_sub <- mvnorm_extract_dim(x, unique(c(v2, v1)))
      layout[v1, v2] <- i
      if (v1 == v2) {
        ## diagonal panel: histogram of one dimension with the
        ## corresponding marginal mixture density on top
        interval <- quantile(samp[, v1], c(0.025, 0.975))
        pl_pairs[[label]] <- bayesplot::mcmc_hist(
          samp[, v1, drop = FALSE],
          binwidth = diff(interval) / 50,
          freq = FALSE
        ) +
          bayesplot::facet_text(FALSE) +
          stat_function(
            inherit.aes = FALSE,
            fun = function(mix, x) dmix(mix, matrix(x, nrow = length(x))),
            args = list(mix = mix_sub)
          ) +
          ylab(dim_labels[v2]) +
          xlab(dim_labels[v2])
      } else {
        ## off-diagonal panel: scatter of the two dimensions with
        ## contours of the log density of the pair-wise marginal mixture
        ## evaluated on a regular grid spanning the sample range
        data_ranges <- apply(samp[, c(v2, v1), drop = FALSE], 2, range)
        colnames(data_ranges) <- c("x", "y")
        data_grid <- expand.grid(apply(
          data_ranges,
          2,
          function(r) seq(r[1], r[2], length = breaks),
          simplify = FALSE
        ))
        data_grid$z <- dmix(mix_sub, as.matrix(data_grid), log = TRUE)
        pl_pairs[[label]] <- bayesplot::mcmc_scatter(
          samp[, c(v2, v1), drop = FALSE],
          alpha = 0.1
        ) +
          bayesplot::facet_text(FALSE) +
          geom_contour(
            aes(z = .data$z),
            data = data_grid,
            bins = nbins,
            colour = "black"
          ) +
          xlab(dim_labels[v2]) +
          ylab(dim_labels[v1])
      }
      ## the compact variant keeps x-axis decoration only in the bottom
      ## row and y-axis decoration only in the first column
      pl_pairs_compact[[label]] <- pl_pairs[[label]]
      if (v1 != p) {
        pl_pairs_compact[[label]] <- pl_pairs_compact[[label]] +
          bayesplot::xaxis_title(FALSE) +
          bayesplot::xaxis_ticks(FALSE) +
          bayesplot::xaxis_text(FALSE)
      }
      if (v2 != 1) {
        pl_pairs_compact[[label]] <- pl_pairs_compact[[label]] +
          bayesplot::yaxis_title(FALSE) +
          bayesplot::yaxis_ticks(FALSE) +
          bayesplot::yaxis_text(FALSE)
      }
    }
    pl$mixpairs <- bayesplot::bayesplot_grid(
      plots = pl_pairs_compact,
      grid_args = list(nrow = p, ncol = p, layout_matrix = layout)
    )
    ## keep the fully decorated individual panels accessible as well
    pl$mixpairs$bayesplots <- pl_pairs
  }

  pl
}

Try the RBesT package in your browser

Any scripts or data that you put into this service are public.

RBesT documentation built on June 8, 2025, 10:05 a.m.