R/PMXStanFit.R


#' @title Generation of a \code{PMXStanFit} object
#'
#' @description
#' Reads in data, runs a compiled Stan executable, generates posterior samples for model parameters, 
#' checks convergence, and performs model diagnostics.
#' 
#' @details
#' Intuitively, a \code{PMXStanFit} object is specified by a \code{PMXStanModel} object that has
#' been compiled successfully, an input list of data prepared to be compatible with the model object, and a
#' group of arguments passed to \code{Stan} to run \code{\link[rstan]{sampling}}. 
#' 
#' In addition to serving as an interface to generate samples with \code{Stan}, \code{PMXStanFit} also provides an
#' interface for users to perform a variety of post-processing procedures by querying these samples, making 
#' predictions, and comparing with observations. Basic functions are available to investigate sampling behavior 
#' and check convergence, return diagnostic statistics for a fitted Bayesian model, and generate goodness-of-fit
#' plots commonly used for PK/PD models, both for the overall population and for subgroups by covariates.
#' 
#' Since all samples from all chains are conveniently accessible, users can also easily extend the capabilities of
#' a \code{PMXStanFit} object by writing their own diagnostic, goodness-of-fit, or visual predictive check functions.
#'
#' @param model a \code{\link{PMXStanModel}} object that needs to have a Stan executable (already compiled from 
#'              the model-specific Stan code) ready for sampling.
#' @param dat a named list that provides the input data for the Stan model and other information such as individual 
#'            \code{ID}'s and relevant covariates (specified by users).
#'            Usually generated by \code{\link{prepareInputData}} or modified from the output of the same function.
#' @param ... any other arguments that are passed to the function \code{\link[rstan]{sampling}}, such as \code{chains},
#'            the number of chains to run (default is 4); \code{iter}, the number of iterations (default is 2000); and 
#'            \code{thin}, the period for saving samples (default is 1).
#'
#' @return
#' A \code{PMXStanFit} object, with the following list of methods:
#' \item{get.fit}{returns the output derived from fitting a Stan model, including the samples; the same output as 
#'                defined by \code{\link[rstan]{stanfit-class}}.
#'               }
#' \item{print.fit}{prints out statistics of posterior samples, with the following arguments:
#' 
#'     \code{on.screen}: a logical variable that controls whether or not to print results on screen. Default is \code{TRUE}.
#'  
#'     \code{save.mode}: a logical variable that controls whether or not to export results to a text file. Default is \code{TRUE}.
#'  
#'     \code{...}: any other arguments that are passed to the more generic function \code{\link[rstan]{print.stanfit}} from 
#'                 \emph{rstan}, such as \code{digits_summary}, number of significant digits for printing out the summary; 
#'                 \code{pars}, parameters whose summaries are of interest;
#'                 \code{probs}, quantiles of interest for summary statistics; etc.
#'                  
#'     This function has a generic form as well (see \emph{Examples}).
#' }
#' \item{get.path}{returns path of the folder that stores all post-processing results, such as printed statistics for 
#'                 samples and goodness-of-fit plots.
#'                }
#' \item{get.waic}{returns diagnostic statistics for a fitted Bayesian model: \emph{Watanabe-Akaike information criterion} 
#'                 (\code{WAIC}) and \emph{Leave-one-out cross-validation} (\code{LOO-CV}).
#'                 The argument \code{complete} selects whether to return all pointwise and total statistics 
#'                 (\code{TRUE}) or only total statistics (\code{FALSE}, by default). For more details, see \emph{References}.
#'                 This function has a generic form as well (see \emph{Examples}).
#'                }
#' \item{plot.trace}{plots traces and posterior distributions corresponding to one or more Markov chains, in order to 
#'                   investigate sampling behavior and to assess mixing across chains and convergence. The argument 
#'                   \code{pars} specifies names of parameters whose traces will be plotted. When it is not specified, 
#'                   the function by default plots all the \code{theta}'s (model parameters), \code{sigma_eta}'s 
#'                   (variance of the inter-individual random effects), and \code{sigma} (variance of the intra-individual
#'                   random effects) in the auto-generated Stan code.
#'                   This function has a generic form as well (see \emph{Examples}).
#'                   For a similar implementation, see \code{\link[rstan]{traceplot}}.
#'                  }
#' \item{plot.gof.pred.obs}{plots medians of predictions vs. observations for goodness-of-fit assessment, with 
#'                          the following arguments:
#'
#'                          \code{by.cov}: a string to specify the covariate name under investigation. If left as
#'                                         \code{NULL} by default, the plot will be generated based on the overall
#'                                         population from all individuals in the input data. To make sure that the
#'                                         specified covariate name be properly recognized, please also specify it
#'                                         in the argument \code{covar} when calling \code{\link{prepareInputData}}.                   
#'
#'                          \code{type}: a string to specify the type of the covariate, can be "categorical" (abbreviated
#'                                       as "cat") or "continuous" (abbreviated as "con"). This argument is ignored
#'                                       (automatically set to \code{NULL}) if \emph{by.cov} is not provided by user.
#'
#'                          \code{cutoff}: a number or vector of numbers to specify the cut-off points by which the
#'                                         subgroups based on a continuous covariate are defined. If a covariate name
#'                                         is specified in \emph{by.cov} and the type is specified as "continuous" in
#'                                         \emph{type}, the default cut-off (when not provided by user) is set as the 
#'                                         median of the corresponding covariate across the population. This argument
#'                                         is ignored for a categorical covariate, where the subgroups will be 
#'                                         automatically determined by all available distinct values of the covariate.
#'
#'                          \code{filename}: a string to specify the path to store the plots. If not provided by user,
#'                                           the plots will be output to screen automatically.
#'
#'                          Note that a good fit usually results in a group of points clustered around the line of 
#'                          \code{x=y} representing alignment between observations and predictions. This function has 
#'                          a generic form as well (see \emph{Examples}).
#'                         }
#' \item{plot.gof.pred.rsd}{plots residuals, calculated as the difference between medians of predictions and observations, 
#'                          vs. predictions as one way of goodness-of-fit assessment, with the same arguments as 
#'                          \code{plot.gof.pred.obs()} above. A good fit usually results in a 
#'                          group of points clustered around the horizontal line of \code{y=0}, without an obvious trend of 
#'                          deviation. This function has a generic form as well (see \emph{Examples}).
#'                         }
#' \item{plot.gof.idv}{plots time profiles of the prediction medians and 95% predictive intervals on top of the time 
#'                     profile of observations. This function has a generic form as well (see \emph{Examples}).
#'                    }
#'
#' @seealso
#' \code{\link{PMXStanModel}} for initialization and compilation of a \code{PMXStanModel} object;
#' \code{\link{prepareInputData}} for transformation of a NONMEM-readable dataset to a list compatible with the 
#' auto-generated model-specific Stan code;
#' \code{\link[rstan]{sampling}} for usage of arguments to draw samples from a compiled Stan model.
#'
#' @references
#' The Stan Development Team. \emph{Stan Modeling Language User's Guide and Reference Manual}. \url{http://mc-stan.org}. 
#'
#' Aki Vehtari and Andrew Gelman. \emph{WAIC and cross-validation in Stan}. 
#' \url{http://www.stat.columbia.edu/~gelman/research/unpublished/waic_stan.pdf}.
#' 
#' Sumio Watanabe. \emph{Asymptotic Equivalence of Bayes Cross Validation and Widely Applicable Information Criterion
#' in Singular Learning Theory}. \url{http://www.jmlr.org/papers/volume11/watanabe10a/watanabe10a.pdf}.
#'
#' @author Yuan Xiong and Wenping Wang
#'
#' @examples
#' \dontrun{
#' ### A population PK model
#' m1 <- PMXStanModel(path = "pk_m1", pk.struct = "1-cmpt", compile = TRUE)
#' print(m1)
#'
#' data("examples_data")
#' dat <- prepareInputData(data.source = d1_nm_poppk, 
#'                         model = m1,
#'                         covar = c("AGE","GENDER")
#'                        )
#' fit <- PMXStanFit(m1, dat, iter = 100, chains = 1)
#' print(fit, on.screen = FALSE)
#' save(m1, dat, fit, file = file.path("pk_m1","ModelFit.RData"))
#'
#' traces(fit)
#' waic(fit)
#' gofplot(fit)
#' 
#' obs.vs.pred(fit, by.cov = "AGE", type = "continuous", 
#'             cutoff = c(50, 60), filename = "obs_pred_by_age.pdf"
#'            )
#' obs.vs.pred(fit, by.cov = "GENDER", type = "categorical", 
#'             filename = "obs_pred_by_gender.pdf"
#'            )
#' rsd.vs.pred(fit, by.cov = "AGE", type = "continuous", 
#'             cutoff = c(50, 60), filename = "rsd_pred_by_age.pdf"
#'            )
#' rsd.vs.pred(fit, by.cov = "GENDER", type = "categorical", 
#'             filename = "rsd_pred_by_gender.pdf"
#'            )
#'
#' ### A population PKPD model
#' ode <- "
#'   C2 = centr/V;
#'   d/dt(depot) =-ka*depot;
#'   d/dt(centr) = ka*depot - ke*centr;
#'   d/dt(eff) = (1+Emax*C2/(C2+EC50))*Kin - Kout*eff;
#' "
#' instant.stan.extension(ode)
#'
#' m2 <- PMXStanModel(type = "PKPD", 
#'                    path = "pkpd_m2",
#'                    ode = ode, 
#'                    theta= c("Emax","EC50"),
#'                    eta = c("Emax","EC50"),
#'                    const = c(V=1, ka=0.5, ke=0.4, Kin=0.5, Kout=0.5),
#'                    obs.state = 3
#'                   )
#' compile(m2)
#'
#' dat2 <- prepareInputData(data.source = d2_nm_poppkpd, 
#'                          model = m2,
#'                          inits = "BSL",
#'                          covar = c("BMK1","BMK2")
#'                         )
#' 
#' fit2 <- PMXStanFit(m2, dat2, iter = 100, chains = 2)
#' print(fit2, on.screen = FALSE)
#' 
#' save(m2, dat2, fit2, file = file.path("pkpd_m2", "ModelFit.RData"))
#' 
#' traces(fit2)
#' waic(fit2)
#' gofplot(fit2)
#'
#' obs.vs.pred(fit2, by.cov = "BMK1", type = "continuous", 
#'             cutoff = 1, filename = "obs_pred_BMK1.pdf"
#'            )
#' obs.vs.pred(fit2, by.cov = "BMK2", type = "categorical", 
#'             filename = "obs_pred_BMK2.pdf"
#'            )
#' rsd.vs.pred(fit2, by.cov = "BMK1", type = "continuous", 
#'             cutoff = 1, filename = "rsd_pred_BMK1.pdf"
#'            )
#' rsd.vs.pred(fit2, by.cov = "BMK2", type = "categorical", 
#'             filename = "rsd_pred_BMK2.pdf"
#'            )
#' }
#
# Documentation last updated: 6Apr2016
#
# Code Last updated: 4Apr2016 by Yuan Xiong
# revised 30Mar2016
#   - renamed and improved function for trace plotting
# revised 10Mar2016
#   - added calculation of WAIC and LOO-CV
# revised 9Mar2016
#   - allowed for diagnostics by covariates
# revised 21Jan2016
#   - implemented default scenario for traceplot
# Created 28Oct2015 by Yuan Xiong

utils::globalVariables(c("ID", "GRP"))
PMXStanFit <- function(model, dat, ...)
{
  .stanmodel <- NULL
  
  # check input arguments
  if(!hasArg(model)){ 
    stop("A PMXStanModel object must be provided.")
  } else {
    .stanmodel <- model$retrieve.stanmodel()
    if(is.null(.stanmodel))
      stop("A compiled Stan model must be provided. \n  Please compile the model defined by PMXStanModel object using compile() function.")
  }
  if(!hasArg(dat)) 
    stop("Input data must be provided.")
  
  # model fit initialization
  .fit <- NULL
  
  ######################################################
  # Sampling
  .standata <- dat$standata  
  .fit <- sampling(.stanmodel, data=.standata, ...)
  
  ######################################################
  # Get relevant information from model specs
  mspecs <- model$get.model.specs()
  .fitpath <- file.path(mspecs$path,"fit")
  .ntheta <- model$get.ntheta()
  .neta <- model$get.neta()
  
  # check/create model path
  if (!file.exists(.fitpath)) {
      dir.create(.fitpath)
  }
  
  ######################################################
  # Print fit statistics
  print.fit <- function(...) 
  {   
    print(.fit, ...)
    # print on screen
    #if(on.screen) print(.fit, ...)
    
    # save results in a text file
    #if(save.mode) capture.output(print(.fit), file = file.path(.fitpath, "summary.txt"))
  }
  
  ######################################################
  # Trace plots
  plot.trace <- function(pars=NULL) 
  {  
    fit.coda <- as.mcmc.list(.fit)
    sel.param <- NULL
    
    # get parameter names
    parnames <- colnames(fit.coda[[1]])
    
    if(is.null(pars)) { # default: plot theta's, sigma_eta's, and sigma 
      idx.theta <- sapply(seq(.ntheta), function(idx){
                    par.name = paste("theta[", idx, "]", sep="")  
                    match(par.name, parnames)
                  })
      idx.eta <- sapply(seq(.neta), function(idx){
                  par.name = paste("sigma_eta[", idx, "]", sep="")  
                  match(par.name, parnames)
                })
      idx.sigma <- match("sigma", parnames)
      sel.param <- c(idx.theta, idx.eta, idx.sigma)
    } else{ # plot user-specified parameters if present
      for(idx in seq_along(pars))
      {
        par.name <- pars[idx]
        if(is.na(match(par.name, parnames))){
          warning(paste("Parameter ", par.name, " was not found. Please double check the name.", sep=""))
        } else {
          # collect matching column indices as a plain vector
          sel.param <- c(sel.param, match(par.name, parnames))
        }
      }
    }
    
    if(is.null(sel.param)){
      stop("None of the user-specified parameter names can be found.")
    }else{
      plot(fit.coda[,sel.param])
    }
      
  }
  
  ######################################################
  # Diagnostics for a fitted Bayesian model
  # - Watanabe-Akaike Information Criterion (WAIC) [Watanabe 2010]
  # - Leave-one-out cross-validation (LOO)
  # modified from Vehtari and Gelman, 2014
  waic <- function(complete = FALSE) 
  {  
    log_lik <- extract(.fit, "log_lik")$log_lik
    
    dim(log_lik) <- if(is.null(dim(log_lik))) c(length(log_lik),1) else c(dim(log_lik)[1], prod(dim(log_lik)[2:length(dim(log_lik))]))    
    S <- nrow(log_lik)
    n <- ncol(log_lik)
    
    # log pointwise predictive density
    lpd <- log(colMeans(exp(log_lik)))
    
    # effective number of parameters
    p_waic <- colVars(log_lik)
    # elpd_waic 
    elpd_waic <- lpd - p_waic
    # Watanabe-Akaike information criterion, interpreted as a computationally convenient approximation to cross-validation
    waic <- -2*elpd_waic
    
    # importance weights [Gelfand, Dey, and Chang (1992)]
    loo_weights_raw <- 1/exp(log_lik - max(log_lik))
    # normalized weights 
    loo_weights_normalized <- loo_weights_raw/matrix(colMeans(loo_weights_raw), nrow=S, ncol=n, byrow=TRUE)
    # regularized/stabilized weights [Ionides (2008)]
    loo_weights_regularized <- pmin(loo_weights_normalized, sqrt(S))
    
    # stabilized importance-sampling LOO estimate of the expected log pointwise predictive density
    elpd_loo <- log(colMeans(exp(log_lik)*loo_weights_regularized)/colMeans(loo_weights_regularized))
    # effective number of parameters for LOO
    p_loo <- lpd - elpd_loo
    
    # combine results 
    pointwise <- cbind(waic, lpd, p_waic, elpd_waic, p_loo, elpd_loo)
    total <- colSums(pointwise)
    se <- sqrt(n*colVars(pointwise))
    
    if (complete) { 
      return(list(waic = total["waic"], elpd_waic = total["elpd_waic"], p_waic = total["p_waic"],
                  elpd_loo = total["elpd_loo"], p_loo = total["p_loo"],
                  pointwise = pointwise, total = total, se = se
                 )
            )
    } else {
      return(list(total = total, se = se))
    }
    
  }
  
  ######################################################
  # GoF plots
  d.type <- dat$data.type
  m.type <- mspecs$type
  res.fit <- extract(.fit)
  
  # median and 95% predictive interval for individual predictions
  yp <- apply(res.fit$y_pred, 2, median)
  yhi <- apply(res.fit$y_pred, 2, quantile, probs = 0.975)
  ylo <- apply(res.fit$y_pred, 2, quantile, probs = 0.025)
  
  if (m.type == "PK") { # PK model
    m.pk.solver <- mspecs$solver
    yobs <- dat$standata$conc
    switch(m.pk.solver,
           "closed_form" = {
             time.plot <- dat$standata$obs_time
           },       
           "ODE" = {
             time.plot <- dat$standata$evt_time[dat$standata$evid==0]
           }
    )
  } else { # PKPD model
      yobs <- dat$standata$y
      time.plot <- dat$standata$evt_time[dat$standata$evid==0]     
  }
  
  ID.exp <- dat$ID
  if(d.type == "individual"){
      ID.exp <- c(dat$ID, dat$ID+1)
  }
  
  df.plot <- rbind(data.frame(ID=rep(ID.exp, dat$standata$NOBS), Time=time.plot, Y=yobs, GRP=1),
   	       	  data.frame(ID=rep(ID.exp, dat$standata$NOBS), Time=time.plot, Y=yp, GRP=2),
   	          data.frame(ID=rep(ID.exp, dat$standata$NOBS), Time=time.plot, Y=ylo, GRP=3),
                  data.frame(ID=rep(ID.exp, dat$standata$NOBS), Time=time.plot, Y=yhi, GRP=4)
                 )
  if(!is.null(dat$d.cov)){
      df.plot <- merge(df.plot, dat$d.cov, by = "ID")
  }
  
  # remove artificial data added to individual data
  if(d.type == "individual"){
    yp <- yp[1:(length(yp)-1)]
    yobs <- yobs[1:(length(yobs)-1)]
    df.plot <- subset(df.plot, ID == df.plot$ID[1])
  }
  
  # sort data by ID and TIME
  df.plot <- df.plot[order(df.plot$ID, df.plot$Time),]
  
  ### pred vs. obs
  plot.gof.pred.obs <- function(by.cov = NULL, type = NULL, cutoff = NULL, filename = NULL) 
  {    
    if(!hasArg(by.cov)){  # no covariate specified
      # plot the overall obs vs. pred
      plot.pred.obs(yobs = yobs, yp = yp, maintext = "Predicted vs. observed")   
    } else { # a covariate has been specified
      if(!hasArg(type)){ # no type of covariate has been specified
        stop("Type of covariate needs to be specified. Please select from \"continuous\" (or \"con\"), and \"categorical\" (or \"cat\").")
      } else { # type of covariate has been specified
          if(type %in% c("categorical", "cat")) { # categorical covariate
            cov.grp <- unique(dat$d.cov[, by.cov])
            if(hasArg(filename)){
              pdf(file.path(.fitpath, filename))
            }
            for (cov.val in cov.grp)
            {
              plot.pred.obs(yobs = df.plot[df.plot$GRP==1 & df.plot[,by.cov]==cov.val, ]$Y,
  	                    yp = df.plot[df.plot$GRP==2 & df.plot[,by.cov]==cov.val, ]$Y,
  	                    maintext = paste("Observed vs. predicted: ", by.cov, "=", cov.val, sep="")
                           )  
            }
            if(hasArg(filename)){ 
	      dev.off()
            }
          } else { # continuous covariate
            if(!hasArg(cutoff)) { # no cutoff for a continuous covariate has been specified
              message("A covariate was specified as continuous, but the cutoff point(s) was not provided. It is set to the median by default.")
              cutoff <- median(dat$d.cov[, by.cov])
            } # if(!hasArg(cutoff))
            
            cutoff <- sort(cutoff)
            # check cutoff points
            if(min(cutoff) < min(dat$d.cov[, by.cov]))
              stop("At least one cutoff point is below the lower bound of the selected covariate.")
            if(max(cutoff) > max(dat$d.cov[, by.cov]))
              stop("At least one cutoff point is beyond the upper bound of the selected covariate.")
            
            n.cov.grp <- length(cutoff) + 1
            if(hasArg(filename)){ 
              pdf(file.path(.fitpath, filename))
            }
            # first group
            plot.pred.obs(yobs = df.plot[df.plot$GRP==1 & df.plot[,by.cov]<=cutoff[1], ]$Y,
  	                  yp = df.plot[df.plot$GRP==2 & df.plot[,by.cov]<=cutoff[1], ]$Y,
  	                  maintext = paste("Observed vs. predicted: ", by.cov, "<=", cutoff[1], sep="")
  	                 )
  	    if(n.cov.grp > 2) { # more than 1 cutoff (thus more than 2 groups) 
              for (cov.grp.idx in 2 : (n.cov.grp-1))
  	      {
  	        plot.pred.obs(yobs = df.plot[df.plot$GRP==1 & df.plot[,by.cov]>cutoff[cov.grp.idx-1] & df.plot[,by.cov]<=cutoff[cov.grp.idx] , ]$Y,
  	                      yp = df.plot[df.plot$GRP==2 & df.plot[,by.cov]>cutoff[cov.grp.idx-1] & df.plot[,by.cov]<=cutoff[cov.grp.idx] , ]$Y,
  	                      maintext = paste("Observed vs. predicted: ", cutoff[cov.grp.idx-1], "<", by.cov, "<=", cutoff[cov.grp.idx], sep="")
  	                     )  
              }
            }
            # last group
            plot.pred.obs(yobs = df.plot[df.plot$GRP==1 & df.plot[,by.cov]>cutoff[n.cov.grp-1], ]$Y,
  	    	          yp = df.plot[df.plot$GRP==2 & df.plot[,by.cov]>cutoff[n.cov.grp-1], ]$Y,
  	    	          maintext = paste("Observed vs. predicted: ", by.cov, ">", cutoff[n.cov.grp-1], sep="")
  	                 )
  	    if(hasArg(filename)){ 
	      dev.off()
            }      
          } # if(type %in% c("categorical", "cat"))
      }  # if(!hasArg(type))
    } # if(!hasArg(by.cov))          
  } # end of function plot.gof.pred.obs
  
  ### individual fitting
  plot.gof.idv <- function() 
  {
    if(d.type == "individual"){
      print(xyplot(Y ~ Time, df.plot, groups = GRP, type = c("p", "l","l","l"), 
                   lty = c(0,1,2,2), pch = c(1, NA, NA, NA), as.table = TRUE, 
                   lwd = c(0,2,1,1), col = c("blue","magenta","magenta","magenta"),
                   xlab = "Time (day)", ylab = "Observation and prediction",
                   main = "Individual fittings"
            ))
    } else { # population
      print(xyplot(Y ~ Time | format(ID), df.plot, groups = GRP, type = c("p", "l","l","l"), 
                   lty = c(0,1,2,2), pch = c(1, NA, NA, NA), as.table = TRUE, 
                   lwd = c(0,2,1,1), col = c("blue","magenta","magenta","magenta"),
                   layout = c(4,3),
                   xlab = "Time (day)", ylab = "Observation and prediction",
                   main = "Individual fittings"
            ))
    }
  }
  
  ### residual plot
  plot.gof.pred.rsd <- function(by.cov = NULL, type = NULL, cutoff = NULL, filename = NULL) 
  {    
    if(!hasArg(by.cov)){  # no covariate specified
      # plot the overall residual vs. pred
      plot.pred.rsd(yobs = yobs, yp = yp, maintext = "Residual vs. prediction")   
    } else { # a covariate has been specified
      if(!hasArg(type)){ # no type of covariate has been specified
        stop("Type of covariate needs to be specified. Please select from \"continuous\" (or \"con\"), and \"categorical\" (or \"cat\").")
      } else { # type of covariate has been specified
          if(type %in% c("categorical", "cat")) { # categorical covariate
            cov.grp <- unique(dat$d.cov[, by.cov])
            if(hasArg(filename)){
              pdf(file.path(.fitpath, filename))
            }
            for (cov.val in cov.grp)
            {
              plot.pred.rsd(yobs = df.plot[df.plot$GRP==1 & df.plot[,by.cov]==cov.val, ]$Y,
  	                    yp = df.plot[df.plot$GRP==2 & df.plot[,by.cov]==cov.val, ]$Y,
  	                    maintext = paste("Residual vs. prediction: ", by.cov, "=", cov.val, sep="")
                           )  
            }
            if(hasArg(filename)){ 
  	      dev.off()
            }
          } else { # continuous covariate
            if(!hasArg(cutoff)) { # no cutoff for a continuous covariate has been specified
              message("A covariate was specified as continuous, but the cutoff point(s) was not provided. It is set to the median by default.")
              cutoff <- median(dat$d.cov[, by.cov])
            } # if(!hasArg(cutoff))
            
            cutoff <- sort(cutoff)
            # check cutoff points
            if(min(cutoff) < min(dat$d.cov[, by.cov]))
              stop("At least one cutoff point is below the lower bound of the selected covariate.")
            if(max(cutoff) > max(dat$d.cov[, by.cov]))
              stop("At least one cutoff point is beyond the upper bound of the selected covariate.")
            
            n.cov.grp <- length(cutoff) + 1
            if(hasArg(filename)){ 
              pdf(file.path(.fitpath, filename))
            }
            # first group
            plot.pred.rsd(yobs = df.plot[df.plot$GRP==1 & df.plot[,by.cov]<=cutoff[1], ]$Y,
  	                  yp = df.plot[df.plot$GRP==2 & df.plot[,by.cov]<=cutoff[1], ]$Y,
  	                  maintext = paste("Residual vs. prediction: ", by.cov, "<=", cutoff[1], sep="")
  	                 )
  	    if(n.cov.grp > 2) { # more than 1 cutoff (thus more than 2 groups) 
              for (cov.grp.idx in 2 : (n.cov.grp-1))
  	      {
  	        plot.pred.rsd(yobs = df.plot[df.plot$GRP==1 & df.plot[,by.cov]>cutoff[cov.grp.idx-1] & df.plot[,by.cov]<=cutoff[cov.grp.idx] , ]$Y,
  	                      yp = df.plot[df.plot$GRP==2 & df.plot[,by.cov]>cutoff[cov.grp.idx-1] & df.plot[,by.cov]<=cutoff[cov.grp.idx] , ]$Y,
  	                      maintext = paste("Residual vs. prediction: ", cutoff[cov.grp.idx-1], "<", by.cov, "<=", cutoff[cov.grp.idx], sep="")
  	                     )  
              }
            }
            # last group
            plot.pred.rsd(yobs = df.plot[df.plot$GRP==1 & df.plot[,by.cov]>cutoff[n.cov.grp-1], ]$Y,
  	    	          yp = df.plot[df.plot$GRP==2 & df.plot[,by.cov]>cutoff[n.cov.grp-1], ]$Y,
  	    	          maintext = paste("Residual vs. prediction: ", by.cov, ">", cutoff[n.cov.grp-1], sep="")
  	                 )
  	    if(hasArg(filename)){ 
  	      dev.off()
            }      
          } # if(type %in% c("categorical", "cat"))
      }  # if(!hasArg(type))
    } # if(!hasArg(by.cov))          
  } # end of function plot.gof.pred.rsd
  
  #####################################################
  # Wrap up
  
  out <- list(get.fit = function() .fit,
             print.fit = print.fit,
             get.path = function() .fitpath,
             get.waic = waic,
             plot.trace = plot.trace,
             plot.gof.pred.obs = plot.gof.pred.obs,
             plot.gof.idv = plot.gof.idv,
             plot.gof.pred.rsd = plot.gof.pred.rsd
            )
  class(out) <- "PMXStanFit"
  out
      
} # end of PMXStanFit
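
# Illustrative sketch (not part of the package API): the subgroup logic that
# plot.gof.pred.obs() and plot.gof.pred.rsd() apply to a continuous covariate,
# i.e. "<= cutoff[1]", "(cutoff[k-1], cutoff[k]]", and "> cutoff[last]", can
# probably be expressed more compactly with base R's cut(). The helper name
# `group.by.cutoff` is hypothetical and is not used elsewhere in this file.
group.by.cutoff <- function(x, cutoff)
{
  cutoff <- sort(cutoff)
  # right-closed intervals, matching the "<=" comparisons used in the plots
  cut(x, breaks = c(-Inf, cutoff, Inf), right = TRUE)
}
# Possible usage (ages are made-up values):
#   table(group.by.cutoff(c(45, 50, 55, 60, 70), cutoff = c(50, 60)))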

#########################################################
### Generic functions

# print parameter statistics
print.PMXStanFit <- function(x, ...)
{
    x$print.fit(...)
}

#' @title Generation of trace plots of the samples
#'
#' @description
#' Plots traces of samples corresponding to one or more Markov chains during the sampling procedure with 
#' \code{Stan}. This is a generic version of the method \code{plot.trace()} for the \code{\link{PMXStanFit}} 
#' class. 
#'
#' @param fit a \code{PMXStanFit} object.
#' @param pars specifies names of parameters whose traces will be plotted. When it is not specified, 
#'             the function by default plots all the \code{theta}'s (model parameters), \code{sigma_eta}'s 
#'             (variance of the inter-individual random effects), and \code{sigma} (variance of the intra-individual
#'             random effects) in the auto-generated Stan code.
#'
#' @seealso
#' \code{\link{PMXStanFit}} for the method \code{plot.trace()}; \code{\link[rstan]{traceplot}} for a similar
#' implementation of plotting traces in \code{rstan}.
#'
#' @examples
#' \dontrun{
#' m <- PMXStanModel(path = "pk_m1", pk.struct = "1-cmpt", compile = TRUE)
#'
#' data("examples_data")
#' dat <- prepareInputData(data.source = d1_nm_poppk, model = m)
#' fit <- PMXStanFit(m, dat, iter=100, chains=1)
#'
#' fit$plot.trace()
#' traces(fit)
#' 
#' fit$plot.trace(pars = c("CL", "V", "sigma"))
#' traces(fit, pars = c("CL", "V", "sigma"))
#' }
traces <- function(fit, pars = NULL) 
{   
    UseMethod("traces", fit)
}
traces.PMXStanFit <- function(fit, pars = NULL)
{   
    pdf(file.path(fit$get.path(), "traces.pdf"))
    fit$plot.trace(pars)
    dev.off()
}
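
# Illustrative sketch (not part of the package API): how the default selection
# in plot.trace() maps the indexed Stan parameter names "theta[i]",
# "sigma_eta[j]", and "sigma" to column positions. The helper name
# `default.trace.index` is hypothetical. It uses sprintf() with seq_len() so
# that a count of zero yields no names at all, which is an assumption about the
# desired behavior rather than what plot.trace() currently does.
default.trace.index <- function(parnames, ntheta, neta)
{
  wanted <- c(sprintf("theta[%d]", seq_len(ntheta)),
              sprintf("sigma_eta[%d]", seq_len(neta)),
              "sigma")
  # positions of the wanted names; NA marks a name that is absent
  match(wanted, parnames)
}
# Possible usage:
#   default.trace.index(c("theta[1]", "theta[2]", "sigma_eta[1]", "sigma"), 2, 1)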

#' @title Diagnostics for a fitted Bayesian model
#'
#' @description
#' Calculates diagnostic statistics for a fitted Bayesian model: \emph{Watanabe-Akaike information criterion} 
#' (\code{WAIC}) and \emph{Leave-one-out cross-validation} (\code{LOO-CV}). This is a generic version of the 
#' method \code{get.waic()} for the \code{\link{PMXStanFit}} class. 
#'
#' @param fit a \code{PMXStanFit} object.
#' @param complete a logical to select whether to return all pointwise and total statistics (\code{TRUE}) 
#'                 or only total statistics (\code{FALSE}, by default).
#'
#' @seealso
#' \code{\link{PMXStanFit}} for the method \code{get.waic()}, and related references.
#'
#' @examples
#' \dontrun{
#' m <- PMXStanModel(path = "pk_m1", pk.struct = "1-cmpt", compile = TRUE)
#'
#' data("examples_data")
#' dat <- prepareInputData(data.source = d1_nm_poppk, model = m)
#' fit <- PMXStanFit(m, dat, iter=100, chains=1)
#'
#' fit$get.waic()
#' waic(fit, complete = TRUE)
#' }
waic <- function(fit, complete = FALSE) 
{   
    UseMethod("waic", fit)
}
waic.PMXStanFit <- function(fit, complete = FALSE)
{   
    fit$get.waic(complete)
}
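
# Illustrative sketch (not part of the package API): the WAIC part of
# get.waic(), applied to a plain S x n matrix of pointwise log-likelihoods
# (S posterior draws in rows, n observations in columns). The helper name
# `waic.sketch` is hypothetical. The per-column variance is computed here with
# apply(), on the assumption that the package's colVars() is the usual sample
# variance.
waic.sketch <- function(log_lik)
{
  lpd <- log(colMeans(exp(log_lik)))   # log pointwise predictive density
  p_waic <- apply(log_lik, 2, var)     # effective number of parameters
  elpd_waic <- lpd - p_waic            # expected log pointwise predictive density
  c(waic = -2 * sum(elpd_waic), lpd = sum(lpd), p_waic = sum(p_waic))
}
# Possible usage with synthetic draws (for illustration only):
#   set.seed(1)
#   waic.sketch(matrix(rnorm(200, mean = -1, sd = 0.1), nrow = 50))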

#' @title Overall goodness-of-fit plots.
#'
#' @description
#' Provides a convenient and fast way to generate three commonly used goodness-of-fit plots for the 
#' whole population in the input data.
#'
#' @details
#' This function is a generic version covering three methods for the \code{\link{PMXStanFit}} class: 
#' \code{plot.gof.pred.obs()}, \code{plot.gof.pred.rsd()}, and \code{plot.gof.idv()}. They plot
#' medians of predictions vs. observations, the differences between medians of predictions and observations
#' vs. predictions, and time profiles of the prediction medians and 95% predictive intervals on top
#' of observations, respectively.
#'
#' @param fit a \code{PMXStanFit} object.
#'
#' @seealso
#' \code{\link{PMXStanFit}} for the methods \code{plot.gof.pred.obs()}, \code{plot.gof.idv()}, 
#' and \code{plot.gof.pred.rsd()}.
#'
#' @examples
#' \dontrun{
#' m <- PMXStanModel(path = "pk_m1", pk.struct = "1-cmpt", compile = TRUE)
#'
#' data("examples_data")
#' dat <- prepareInputData(data.source = d1_nm_poppk, model = m)
#'
#' fit <- PMXStanFit(m, dat, iter=100, chains=1)
#'
#' fit$plot.gof.pred.obs()
#' fit$plot.gof.idv()
#' fit$plot.gof.pred.rsd()
#' 
#' gofplot(fit)
#' }
gofplot <- function(fit)  # overall gof
{
    UseMethod("gofplot", fit)
}
gofplot.PMXStanFit <- function(fit)
{
    pdf(file.path(fit$get.path(), "overall.gof.pdf"))
    on.exit(dev.off(), add = TRUE)  # close the PDF device even if a plot fails
    fit$plot.gof.pred.obs()
    fit$plot.gof.idv()
    fit$plot.gof.pred.rsd()
    invisible(NULL)
}

#' @title Pointwise comparison between observed data vs. predicted medians.
#'
#' @description
#' Plots medians of predictions vs. observations for goodness-of-fit assessment, either for the
#' whole population or for subgroups classified by a given covariate. This function is a generic 
#' version of the method \code{plot.gof.pred.obs()} for the \code{\link{PMXStanFit}} class.
#'
#' @param fit a \code{PMXStanFit} object.
#' @param ... additional arguments.
#' @param by.cov a string to specify the name of the covariate under investigation. If left as
#'               \code{NULL} by default, the plot will be generated based on the overall
#'               population from all individuals in the input data.
#' @param type a string to specify the type of the covariate, can be "categorical" (abbreviated
#'             as "cat") or "continuous" (abbreviated as "con"). This argument is ignored
#'             (automatically set to \code{NULL}) if \emph{by.cov} is not provided by the user.
#' @param cutoff a number or vector of numbers to specify the cut-off points by which the
#'               subgroups based on a continuous covariate are defined. If a covariate name
#'               is specified in \emph{by.cov} and the type is specified as "continuous" in
#'               \emph{type}, the default cut-off (when not provided by the user) is set as the
#'               median of the corresponding covariate across the population. This argument
#'               is ignored for a categorical covariate, where the subgroups will be 
#'               automatically determined by all available distinct values of the covariate.
#' @param filename a string to specify the path of the pdf file to store the plots. If not 
#'                 provided by the user, the plots will be output to screen automatically.
#'
#' @seealso
#' \code{\link{PMXStanFit}} for the method \code{plot.gof.pred.obs()}.
#'
#' @examples
#' \dontrun{
#' m <- PMXStanModel(path = "pk_m1", pk.struct = "1-cmpt", compile = TRUE)
#'
#' data("examples_data")
#' dat <- prepareInputData(data.source = d1_nm_poppk, model = m)
#'
#' fit <- PMXStanFit(m, dat, iter = 100, chains = 1)
#'
#' fit$plot.gof.pred.obs(by.cov = "AGE", type = "continuous", cutoff = c(50, 60),
#'                       filename = "obs_pred_by_age.pdf")
#' obs.vs.pred(fit, by.cov = "AGE", type = "continuous", cutoff = c(50, 60),
#'             filename = "obs_pred_by_age.pdf")
#' }
obs.vs.pred <- function(fit, ...)
{
    UseMethod("obs.vs.pred", fit)
}
obs.vs.pred.PMXStanFit <- function(fit, ...)
{
    fit$plot.gof.pred.obs(...)
}
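
# Sketch of the dispatch pattern used by the generics in this file, written against a
# hypothetical toy object rather than a real PMXStanFit (which needs a compiled Stan model).
# The names toy.fit, ToyFit and toy.obs.vs.pred are illustrative assumptions only. The S3
# generic dispatches on the class of `fit`, and the method forwards `...` unchanged to the
# object's own plotting method, so both call styles accept the same arguments.
toy.fit <- structure(
  list(
    # stand-in for the object's plot.gof.pred.obs(); returns a label instead of plotting
    plot.gof.pred.obs = function(by.cov = NULL, ...) {
      if (is.null(by.cov)) "overall" else paste("by", by.cov)
    }
  ),
  class = "ToyFit"
)
toy.obs.vs.pred <- function(fit, ...) UseMethod("toy.obs.vs.pred", fit)
toy.obs.vs.pred.ToyFit <- function(fit, ...) fit$plot.gof.pred.obs(...)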

#' @title Pointwise comparison between residuals vs. predicted medians.
#'
#' @description
#' Plots residuals, calculated as the difference between medians of predictions and observations, 
#' vs. predictions. This function is a generic version of the method \code{plot.gof.pred.rsd()} for the 
#' \code{\link{PMXStanFit}} class.
#' 
#' @param fit a \code{PMXStanFit} object.
#' @param ... additional arguments.
#' @param by.cov a string to specify the name of the covariate under investigation. If left as
#'               \code{NULL} by default, the plot will be generated based on the overall
#'               population from all individuals in the input data.
#' @param type a string to specify the type of the covariate, can be "categorical" (abbreviated
#'             as "cat") or "continuous" (abbreviated as "con"). This argument is ignored
#'             (automatically set to \code{NULL}) if \emph{by.cov} is not provided by the user.
#' @param cutoff a number or vector of numbers to specify the cut-off points by which the
#'               subgroups based on a continuous covariate are defined. If a covariate name
#'               is specified in \emph{by.cov} and the type is specified as "continuous" in
#'               \emph{type}, the default cut-off (when not provided by the user) is set as the
#'               median of the corresponding covariate across the population. This argument
#'               is ignored for a categorical covariate, where the subgroups will be 
#'               automatically determined by all available distinct values of the covariate.
#' @param filename a string to specify the path of the pdf file to store the plots. If not 
#'                 provided by the user, the plots will be output to screen automatically.
#'
#' @seealso
#' \code{\link{PMXStanFit}} for the method \code{plot.gof.pred.rsd()}.
#'
#' @examples
#' \dontrun{
#' m <- PMXStanModel(path = "pk_m1", pk.struct = "1-cmpt", compile = TRUE)
#'
#' data("examples_data")
#' dat <- prepareInputData(data.source = d1_nm_poppk, model = m)
#'
#' fit <- PMXStanFit(m, dat, iter = 100, chains = 1)
#'
#' obs.vs.pred(fit, by.cov = "GENDER", type = "categorical", filename = "obs_pred_by_gender.pdf")
#' rsd.vs.pred(fit, by.cov = "GENDER", type = "categorical", filename = "gof_by_gender.pdf")
#' }
rsd.vs.pred <- function(fit, ...)
{
    UseMethod("rsd.vs.pred", fit)
}
rsd.vs.pred.PMXStanFit <- function(fit, ...)
{
    fit$plot.gof.pred.rsd(...)
}
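
# Sketch of how `cutoff` could split a continuous covariate into subgroups, as described
# for the by.cov/type/cutoff arguments above. This is an illustration only: the helper
# name subgroup.sketch is hypothetical, and the handling of values that fall exactly on a
# cut-off (assigned to the lower group here, the default of cut()) is an assumption, not
# necessarily what plot.gof.pred.obs() or plot.gof.pred.rsd() do internally.
subgroup.sketch <- function(x, cutoff = NULL) {
  # documented default: a single cut-off at the population median of the covariate
  if (is.null(cutoff)) cutoff <- median(x)
  # k cut-off points define k + 1 subgroups
  cut(x, breaks = c(-Inf, sort(cutoff), Inf))
}
# e.g. subgroup.sketch(c(42, 55, 61, 50, 67, 58), cutoff = c(50, 60)) yields three groups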

#' @title Comparison of the time profiles of observations with the predictions for each individual.
#'
#' @description
#' Plots time profiles of the prediction medians and 95% predictive intervals on top of the time 
#' profile of observations for individual patients. This function is a generic version of the method 
#' \code{plot.gof.idv()} for the \code{\link{PMXStanFit}} class.
#' 
#' @param fit a \code{PMXStanFit} object.
#'
#' @seealso
#' \code{\link{PMXStanFit}} for the method \code{plot.gof.idv()}.
#'
#' @examples
#' \dontrun{
#' m <- PMXStanModel(path = "pk_m1", pk.struct = "1-cmpt", compile = TRUE)
#'
#' data("examples_data")
#' dat <- prepareInputData(data.source = d1_nm_poppk, model = m)
#'
#' fit <- PMXStanFit(m, dat, iter = 100, chains = 1)
#'
#' fit$plot.gof.idv()
#' idv.obs.pred.vs.time(fit)
#' }
idv.obs.pred.vs.time <- function(fit)
{
    UseMethod("idv.obs.pred.vs.time", fit)
}
idv.obs.pred.vs.time.PMXStanFit <- function(fit)
{
    fit$plot.gof.idv()
}

#########################################################
### External functions

# Plot observation vs. prediction
plot.pred.obs <- function(yobs, yp, maintext) {
  plot(yp, yobs, xlab = "Prediction", ylab = "Observation",
       main = maintext, col = "blue")
  abline(0, 1, col = "gray", lwd = 3)  # line of identity
  # sort by prediction so the smoother is drawn left to right
  df.y <- data.frame(dat.y = yobs, pred.y = yp)
  df.y <- df.y[order(df.y$pred.y), ]
  fit.loess <- loess(dat.y ~ pred.y, data = df.y, span = 1)
  lines(df.y$pred.y, predict(fit.loess), col = "red", lty = 2, lwd = 3)
}
  
# Plot residual vs. prediction
plot.pred.rsd <- function(yobs, yp, maintext) {
  res <- yobs - yp  # residual = observation - prediction
  plot(yp, res, xlab = "Prediction", ylab = "Residual",
       main = maintext, col = "blue")
  abline(h = 0, col = "gray", lwd = 3)  # zero-residual reference line
  # sort by prediction so the smoother is drawn left to right
  df.res <- data.frame(res = res, pred.y = yp)
  df.res <- df.res[order(df.res$pred.y), ]
  fit.loess <- loess(res ~ pred.y, data = df.res, span = 1)
  lines(df.res$pred.y, predict(fit.loess), col = "red", lty = 2, lwd = 3)
}
  
# Calculate columnwise variances
colVars <- function(a) {
  vars <- a[1, ]
  for (n in seq_along(vars)) vars[n] <- var(a[, n])
  return(vars)
}
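
# Sketch of where columnwise variances typically enter a WAIC calculation. This is an
# assumption about the general technique, not a copy of the get.waic() method: the helper
# name waic.sketch and its input layout (an S x N matrix of pointwise log-likelihoods, one
# row per posterior draw, one column per observation) are hypothetical. apply(., 2, var)
# is used so the sketch stands alone; on a plain matrix it gives the same values as colVars().
waic.sketch <- function(log.lik) {
  # log pointwise predictive density (naive form; may underflow for very small likelihoods)
  lppd <- sum(log(colMeans(exp(log.lik))))
  # effective number of parameters: sum of posterior variances of the pointwise log-likelihood
  p.waic <- sum(apply(log.lik, 2, var))
  c(lppd = lppd, p.waic = p.waic, waic = -2 * (lppd - p.waic))
}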


stanette documentation built on May 11, 2022, 5:11 p.m.