R/plotQVI.R

Defines functions plotQVI

Documented in plotQVI

#' @title plot variable importance comparison by quantile
#' @description
#' Computes the quantile ALE-induced variable importance (VI) measure for each of the
#' covariate specified in var.index, and produces a ranking plot of the covariates using
#' bar plot for each quantile of interest.
#'
#' @name plotQVI
#'
#' @param object An object of class \code{SPQR}.
#' @param var.index A vector specifying the index of the covariates for which VI measures should be computed.
#'   Default is \code{NULL} indicating all covariates are considered.
#' @param var.names The names of the covariates to appear in the bar plots. Default is \code{NULL} and the
#'   the function will use generic names generated by \code{parse(text=paste0("X[",var.index,"]"))}.
#' @inheritDotParams QALE -object -var.index -getAll
#'
#' @return A \code{ggplot} object.
#'
#' @import ggplot2
#'
#' @examples
#' \donttest{
#' set.seed(919)
#' n <- 200
#' X <- matrix(runif(n*2, 0, 2), nrow = n, ncol = 2)
#' Y <- rnorm(n, X[,1]^2, 0.3+X[,1]/2)
#' control <- list(iter = 200, warmup = 150, thin = 1)
#' fit <- SPQR(X=X, Y=Y, n.knots=12, n.hidden=5, method="MCMC",
#'             control=control, normalize=TRUE, verbose = FALSE)
#'
#' ## compute quantile VI of at tau = 0.2,0.5,0.8
#' plotQVI(fit, tau=c(0.2,0.5,0.8))
#' }
#'
#' @export
plotQVI <- function(object, var.index = NULL, var.names = NULL, ...) {

  dotparams <- list(...)
  tau <- dotparams$tau
  ci.level <- dotparams$ci.level
  if (is.null(tau)) tau <- seq(0.1,0.9,0.1)
  if (is.null(ci.level)) ci.level <- 0

  if (is.null(var.index)) var.index <- 1:ncol(object$X)
  else stopifnot(length(var.index)>1)
  if (!is.null(var.names) && length(var.names) != length(var.index))
    stop("`var.names` must have the same length as 'var.index'.")

  if (object$method != "MCMC") ci.level <- 0
  x.ticks <- {}
  if (!is.null(var.names)) x.ticks <- var.names
  else x.ticks <- paste0("X[",var.index,"]")
  names(x.ticks) <- var.index
  tauexp <- factor(tau, levels=tau, labels=paste0("tau==",tau))
  if (ci.level == 0) {
    vi <- matrix(nrow=length(var.index),ncol=length(tau))
    for (i in 1:nrow(vi)) {
      ale <- QALE(object, var.index=var.index[i], ...)
      if (length(ale$x)>5) vi[i,] <- apply(ale$ALE,2,stats::sd)
      else vi[i,] <- apply(ale$ALE,2,function(x) max(x)-min(x))/4
    }
    .df <- data.frame(x=x.ticks)
    df <- do.call(rbind,lapply(seq_along(tau), FUN = function(i) {
      .df$y <- vi[,i]
      #x.ticks <- x.ticks[order(df$y, decreasing = T)]
      .df <- .df[order(.df$y, decreasing = T),]
      .df$tauexp <- tauexp[i]
      return(.df)
    }))
    p <-
      ggplot(data=df, aes(x=.reorder_within(.data$x,-.data$y,.data$tauexp), y=.data$y)) +
      geom_bar(stat="identity",fill="#999999")
  } else {
    nnn <- length(object$model)
    .vi <- array(dim=c(length(var.index),length(tau),nnn))
    vi <- array(dim=c(length(var.index),length(tau),3))
    for (i in 1:nrow(vi)) {
      ale <- QALE(object, var.index=var.index[i], getAll=TRUE, ...)
      if (length(ale$x)>5) .vi[i,,] <- apply(ale$ALE,c(2,3),stats::sd)
      else .vi[i,,] <- apply(ale$ALE,c(2,3),function(x) max(x)-min(x))/4
      vi[i,,1] <- apply(.vi[i,,],1,stats::quantile,probs=(1-ci.level)/2)
      vi[i,,2] <- apply(.vi[i,,],1,mean)
      vi[i,,3] <- apply(.vi[i,,],1,stats::quantile,probs=(1+ci.level)/2)
    }
    .df <- data.frame(x=x.ticks)
    df <- do.call(rbind,lapply(seq_along(tau), FUN = function(i) {
      .df$y <- as.vector(vi[,i,2])
      .df$ymin <- as.vector(vi[,i,1])
      .df$ymax <- as.vector(vi[,i,3])
      #x.ticks <- x.ticks[order(df$y, decreasing = T)]
      .df <- .df[order(.df$y, decreasing = T),]
      .df$tauexp <- tauexp[i]
      return(.df)
    }))
    p <-
      ggplot(data=df, aes(x=.reorder_within(.data$x,-.data$y,.data$tauexp), y=.data$y)) +
      geom_bar(stat="identity",fill="#999999") +
      geom_errorbar(aes(ymin=.data$ymin,ymax=.data$ymax),color="#000000")
  }
  p <- p +
    theme_bw() +
    scale_x_discrete(labels = function(x, sep = "___") {
      reg <- paste0(sep, ".+$")
      parse(text=gsub(reg, "", x))
    }) +
    facet_wrap(~tauexp, labeller=label_parsed, scales='free_x') +
    labs(x=NULL, y="Importance") +
    theme(panel.grid.major=element_blank(),
          panel.grid.minor=element_blank(),
          panel.spacing=unit(0, "lines"),
          axis.title=element_text(size = 15),
          axis.text=element_text(colour="black", size = 12),
          strip.text=element_text(size = 15))
  return(p)
}

.reorder_within <- function (x, by, within, fun = mean, sep = "___", ...) {
  if (!is.list(within)) within <- list(within)
  new_x <- do.call(paste, c(list(x, sep = sep), within))
  stats::reorder(new_x, by, FUN = fun)
}

Try the SPQR package in your browser

Any scripts or data that you put into this service are public.

SPQR documentation built on May 3, 2022, 1:08 a.m.