R/stanemax-methods.R

Defines functions create_exposure_range plot.stanemaxbin plot.stanemax extract_obs_mod_frame extract_stanfit replace_prm_names print.stanemaxbin print.stanemax

Documented in extract_obs_mod_frame extract_stanfit plot.stanemax plot.stanemaxbin print.stanemax print.stanemaxbin

#' Methods for stanemax objects
#'
#' @name stanemax-methods
#'
NULL


#' @rdname stanemax-methods
#' @param x An object of class `stanemax` or `stanemaxbin`
#' @param digits_summary The number of significant digits to use when printing
#' the summary, defaulting to 2. Applies to the quantities other than the
#' effective sample size, which is always rounded to the nearest integer.
#' @export
print.stanemax <- function(x, digits_summary = 2, ...) {
  param.extract <- c("emax", "e0", "ec50", "gamma", "sigma")

  s <- replace_prm_names(x, param.extract)

  cat("---- Emax model fit with rstanemax ----\n")
  cat("\n")
  print(round(s, digits_summary))
  cat("\n")
  cat("* Use `extract_stanfit()` function to extract raw stanfit object\n")
  cat("* Use `extract_param()` function to extract posterior draws of key parameters\n")
  cat("* Use `plot()` function to visualize model fit\n")
  cat("* Use `posterior_predict()` or `posterior_predict_quantile()` function to get\n")
  cat("  raw predictions or make predictions on new data\n")
  cat("* Use `extract_obs_mod_frame()` function to extract raw data \n")
  cat("  in a processed format (useful for plotting)\n")
}


#' @rdname stanemax-methods
#' @export
print.stanemaxbin <- function(x, digits_summary = 2, ...) {
  param.extract <- c("emax", "e0", "ec50", "gamma")

  s <- replace_prm_names(x, param.extract)

  cat("---- Binary Emax model fit with rstanemax ----\n")
  cat("\n")
  print(round(s, digits_summary))
  cat("\n")
  cat("* Use `extract_stanfit()` function to extract raw stanfit object\n")
  cat("* Use `extract_param()` function to extract posterior draws of key parameters\n")
  cat("* Use `plot()` function to visualize model fit\n")
  cat("* Use `extract_obs_mod_frame()` function to extract raw data \n")
  cat("  in a processed format (useful for plotting)\n")
}


replace_prm_names <- function(x, pars) {
  s <- rstan::summary(x$stanfit, pars = pars)$summary

  stbl <- dplyr::as_tibble(s, rownames = "prmname")

  ## Get parameter names in stanfit
  prm.df <-
    stbl %>%
    tidyr::separate(prmname, into = c("prm", "index", "del"), fill = "right", remove = FALSE) %>%
    dplyr::mutate(index = as.numeric(index)) %>%
    dplyr::select(prmname, prm, index)


  ## Get factor levels in covariates
  cov.levels <- x$prminput$cov.levels

  prm.cov.df.list <- list()

  for (k in names(cov.levels)) {
    if (length(cov.levels[[k]]) > 1) {
      prm.cov.df.list[[k]] <-
        dplyr::tibble(
          prm = k,
          level = cov.levels[[k]]
        ) %>%
        dplyr::mutate(
          index = dplyr::row_number(),
          level = as.character(level)
        )
    } else {
      prm.cov.df.list[[k]] <-
        dplyr::tibble(prm = k, index = 1, level = NA_character_)
    }
  }
  prm.cov.df <-
    dplyr::bind_rows(prm.cov.df.list)

  ## Merge back to the original data frame
  stbl2 <-
    dplyr::left_join(prm.df, prm.cov.df, by = c("prm", "index")) %>%
    dplyr::mutate(prmname2 = ifelse(is.na(level), prm, paste0(prm, "[", level, "]"))) %>%
    dplyr::left_join(stbl, ., by = "prmname") %>%
    dplyr::select(-c(prmname, prm, index, level))

  s2 <- as.matrix(dplyr::select(stbl2, -prmname2))
  rownames(s2) <- stbl2$prmname2

  return(s2)
}



#' @rdname stanemax-methods
#' @export
extract_stanfit <- function(x) {
  return(x$stanfit)
}

#' @rdname stanemax-methods
#' @export
extract_obs_mod_frame <- function(x) {
  return(x$prminput$df.model)
}


#' @rdname stanemax-methods
#' @export
#' @param show.ci An logical specifying if the output figure include
#' credible interval of posterior prediction. Default TRUE.
#' @param show.pi An logical specifying if the output figure include
#' prediction interval. Default FALSE.
#' @param ci Credible interval range.
#' @param pi Prediction interval range.
#' @param ... Additional arguments passed to methods.
plot.stanemax <- function(x, show.ci = TRUE, show.pi = FALSE,
                          ci = 0.9, pi = 0.9, ...) {
  obs <- x$prminput$df.model

  exposure.range <- create_exposure_range(x$standata)

  cov.list <-
    obs %>%
    dplyr::select(covemax, covec50, cove0) %>%
    dplyr::distinct()

  cov.list.2 <- create_cov_groups(cov.list, x$prminput$param.cov)

  newdata <- tidyr::crossing(exposure.range, cov.list.2)

  pred.quantile <- posterior_predict_quantile(x,
    newdata = newdata, newDataType = "modelframe",
    ci = ci, pi = pi
  )


  if (is.null(x$prminput$param.cov)) {
    g <- ggplot2::ggplot(pred.quantile, ggplot2::aes(exposure, ci_med))
  } else {
    g <- ggplot2::ggplot(pred.quantile, ggplot2::aes(exposure, ci_med,
      group = Covariates,
      color = Covariates,
      fill = Covariates
    ))
  }

  g <-
    g +
    ggplot2::geom_point(data = obs, ggplot2::aes(y = response)) +
    ggplot2::labs(y = "response")

  if (show.ci) g <- g + ggplot2::geom_ribbon(ggplot2::aes(ymin = ci_low, ymax = ci_high), alpha = .5, color = NA)
  if (show.pi) g <- g + ggplot2::geom_ribbon(ggplot2::aes(ymin = pi_low, ymax = pi_high), alpha = .2, color = NA)

  g <- g + ggplot2::geom_line()


  g
}

#' @rdname stanemax-methods
#' @export
plot.stanemaxbin <- function(x, show.ci = TRUE, show.pi = FALSE,
                             ci = 0.9, pi = 0.9, ...) {
  obs <- x$prminput$df.model

  exposure.range <- create_exposure_range(x$standata)

  cov.list <-
    obs %>%
    dplyr::select(covemax, covec50, cove0) %>%
    dplyr::distinct()

  cov.list.2 <- create_cov_groups(cov.list, x$prminput$param.cov)

  newdata <- tidyr::crossing(exposure.range, cov.list.2)

  pred.quantile <- posterior_predict_quantile(
    x,
    newdata = newdata, newDataType = "modelframe", ci = ci
  )

  if (is.null(x$prminput$param.cov)) {
    g <- ggplot2::ggplot(pred.quantile, ggplot2::aes(exposure, ci_med))
  } else {
    g <- ggplot2::ggplot(pred.quantile, ggplot2::aes(exposure, ci_med,
      group = Covariates,
      color = Covariates,
      fill = Covariates
    ))
  }

  g <-
    g +
    ggplot2::geom_jitter(
      data = obs, ggplot2::aes(y = response),
      width = 0, height = 0.05, alpha = 0.5
    ) +
    ggplot2::coord_cartesian(ylim = c(-0.05, 1.05)) +
    ggplot2::labs(y = "response")

  if (show.ci) g <- g + ggplot2::geom_ribbon(ggplot2::aes(ymin = ci_low, ymax = ci_high), alpha = .5, color = NA)

  g <- g + ggplot2::geom_line()

  g
}


create_exposure_range <- function(standata, length.out = 50) {
  min.newdata <- min(standata$exposure)
  min.nozero.newdata <- min(standata$exposure[standata$exposure > 0])
  max.newdata <- max(standata$exposure)

  seq.normal.scale <- seq(min.newdata, max.newdata, length.out = length.out)
  seq.log.scale <- exp(seq(log(min.nozero.newdata),
    log(max.newdata),
    length.out = length.out
  ))

  exposure.range <- dplyr::tibble(exposure = sort(c(seq.normal.scale, seq.log.scale)))

  return(exposure.range)
}

Try the rstanemax package in your browser

Any scripts or data that you put into this service are public.

rstanemax documentation built on April 4, 2025, 2:39 a.m.