R/plot-etiology-regression.R

Defines functions plot_case_study plot_etiology_strat plot_subwt_regression plot_etiology_regression

Documented in plot_case_study plot_etiology_regression plot_etiology_strat plot_subwt_regression

# Register names that are referenced without a visible binding so that
# `R CMD check` does not report them as undefined global variables.
# NOTE(review): presumably these are data-frame columns used through
# non-standard evaluation in plotting code elsewhere in this file -- confirm.
# `utils::globalVariables()` only exists from R 2.15.1 onwards, hence the guard.
if (getRversion() >= "2.15.1") {
  utils::globalVariables(c("eti_mean", "ci_025", "ci_975", "prob"))
}


#' visualize the etiology regression with a continuous covariate
#' 
#' This function visualizes the etiology regression against one continuous covariate, 
#' currently the enrollment date (`ENRLDATE`, with its standardized version `std_date`). 
#' For every cause it draws two panels: 1) the fitted marginal positive rates among 
#' controls and cases for the chosen bronze-standard (BrS) slice, and 2) the fitted 
#' etiologic fraction curve together with the overall (subject-averaged) etiology.
#' (NB: combo causes and "NoA" are not handled specially: the j-th cause is paired with the
#' j-th column of the BrS slice, and causes without a matching measurement get an empty 
#' top panel. Other continuous covariates are not supported yet.)
#' 
#' @param DIR_NPLCM File path to the folder containing posterior samples
#' @param stratum_bool a vector of TRUE/FALSE with TRUE indicating the rows of subjects to include;
#' it is combined element-wise with the case/control indicator `Y`, so it should have one 
#' entry per subject.
#' @param slice integer; specifies which slice of bronze-standard data to visualize; Default to 1.
#' @param plot_basis TRUE for plotting basis functions; Default to FALSE
#' @param truth a list of truths computed from true parameters in simulations; elements: 
#'  Eti, FPR, PR_case,TPR; All default to `NULL` in real data analyses.
#'  Currently only works for one slice of bronze-standard measurements (in a non-nested model).
#'  \itemize{
#'      \item Eti matrix of # of rows = # of subjects, # columns: `length(cause_list)` for Eti
#'      \item FPR matrix of # of rows = # of subjects, # columns: `ncol(data_nplcm$Mobs$MBS$MBS1)`
#'      \item PR_case matrix of # of rows = # of subjects, # columns: `ncol(data_nplcm$Mobs$MBS$MBS1)`
#'      \item TPR a vector of length identical to `PR_case`
#'  }
#' @param RES_NPLCM pre-read res_nplcm; default to NULL, in which case the posterior samples
#' are read from `CODAchain1.txt` and `CODAindex.txt` in `DIR_NPLCM`.
#' @param do_plot TRUE for plotting; if FALSE no graphics device is touched.
#' @param do_rug TRUE for adding rug plots of the standardized dates of subjects with 
#' positive (top) and negative (bottom) measurements to the positive-rate panels.
#' @param return_metric TRUE for showing overall mean etiology, quantiles, s.d., and if `truth$Eti` is supplied, 
#'  coverage, bias, truth and integrated mean squared errors (IMSE).
#' @param plot_ma_dots plot moving averages among case and controls if TRUE; Default to FALSE.
#' 
#' 
#' @return If `do_plot=TRUE`, a figure of etiology regression curves and some marginal 
#' positive rate assessment of model fit is drawn; See example for the legends. 
#' If `return_metric=TRUE`, a list assembled by `make_list()` is returned. Without 
#' `truth$Eti` it holds `Eti_overall_mean`, `Eti_overall_q`, `Eti_overall_sd`, `res` 
#' (a summary matrix with columns `post.mean`, `post.sd`, `CrI_025`, `CrI_0975`), 
#' `beta_res` (summaries of the etiology regression coefficients contrasted against the 
#' last cause) and `parsed_model`; with `truth$Eti` it holds `Eti_overall_cover`, 
#' `Eti_overall_bias`, `Eti_overall_truth` and `Eti_ISE` instead of `beta_res`. 
#' If `return_metric=FALSE` the function returns `NULL` invisibly.
#' 
#' @import graphics
#' @importFrom lubridate day days_in_month month year
#'        
#' @references See example figures 
#' \itemize{
#' \item A Figure using simulated data for six pathogens: 
#' <https://github.com/zhenkewu/baker/blob/master/inst/figs/visualize_etiology_regression_SITE=1.pdf>
#' \item The legends for the figure above: 
#' <https://github.com/zhenkewu/baker/blob/master/inst/figs/legends_visualize_etiology_regression.png>
#' }
#' @family visualization functions
#'    
plot_etiology_regression <- function(DIR_NPLCM,stratum_bool,slice=1,plot_basis=FALSE,
                                     truth=NULL,RES_NPLCM=NULL,do_plot=TRUE,do_rug=TRUE, 
                                     return_metric=TRUE,
                                     plot_ma_dots=FALSE){
  if (!is_jags_folder(DIR_NPLCM)){
    stop("==[baker] Oops, not a folder baker recognizes. Try a folder generated by baker, e.g., a temporary folder?==")
  }
  # JAGS:
  #
  # Read data from DIR_NPLCM:
  #
  data_nplcm    <- dget(file.path(DIR_NPLCM,"data_nplcm.txt"))  
  model_options <- dget(file.path(DIR_NPLCM,"model_options.txt"))
  parsed_model  <- assign_model(model_options,data_nplcm)
  is_nested     <- parsed_model$nested
  
  if (do_plot){
    cat("==[baker] plotting etiology regression with 
        >>",c("nested", "non-nested")[2-is_nested],"<< model for 
        BrS Measure slice = ",slice,": ",names(data_nplcm$Mobs$MBS)[[slice]]," .==\n")
  }
  # the data passed to JAGS are stored as R assignments; evaluate them in a 
  # scratch environment and keep them as a list:
  new_env <- new.env()
  source(file.path(DIR_NPLCM,"jagsdata.txt"),local=new_env)
  bugs.dat <- as.list(new_env)
  rm(new_env)
  if (!is.null(RES_NPLCM)){
    res_nplcm <- RES_NPLCM
  } else {
    res_nplcm <- coda::read.coda(file.path(DIR_NPLCM,"CODAchain1.txt"),
                                 file.path(DIR_NPLCM,"CODAindex.txt"),
                                 quiet=TRUE)
  }
  # pick the posterior sample columns whose names match the regular expression `x`:
  get_res   <- function(x) res_nplcm[,grep(x,colnames(res_nplcm))]
  
  # structure the posterior samples:
  ncol_dm_FPR <- ncol(bugs.dat[[paste0("Z_FPR_",slice)]]) # design matrix
  JBrS        <- ncol(bugs.dat[[paste0("MBS_",slice)]])
  n_samp_kept   <- nrow(res_nplcm)
  ncol_dm_Eti   <- ncol(bugs.dat$Z_Eti)
  Jcause        <- bugs.dat$Jcause
  Nd            <- bugs.dat$Nd
  Nu            <- bugs.dat$Nu
  K_curr        <- model_options$likelihood$k_subclass[slice]
  templateBS    <- bugs.dat[[paste0("templateBS_",slice)]]
  MBS_slice     <- data_nplcm$Mobs$MBS[[slice]] # observed BrS data of the chosen slice.
  
  #####################################################################
  # add x-axis for dates:
  X <- data_nplcm$X
  
  if(is.null(X$ENRLDATE) || is.null(X$std_date)){
    stop("'ENRLDATE' and/or 'std_date' is not a variable in your dataset! 
         Make sure that the continuous covariate exists and retry this function.") # can we do other covariates?
  }
  # Map calendar dates to the standardized date scale used in the regression.
  # Compare on the `Date` scale (`dd`) so the lookup also works if ENRLDATE is 
  # stored as a date-time.
  dd <-  as.Date(X$ENRLDATE)
  min_d <- min(dd)
  min_d_std <- unique(X$std_date[which(dd==min_d)])
  # first day of the month following the earliest enrollment; `days_in_month()` 
  # is given the date itself so that leap-year Februaries are counted correctly:
  min_plot_d <- min_d+days_in_month(min_d)-day(min_d)+1
  
  max_d <- max(dd)
  max_d_std <- unique(X$std_date[which(dd==max_d)])
  # first day of the month of the latest enrollment:
  max_plot_d <- max_d-day(max_d)+1
  plot_d <- seq.Date(min_plot_d,max_plot_d,by = "quarter")
  
  # standardized-date units per calendar day; used to place the axis ticks:
  unit_x <- (max_d_std-min_d_std)/as.numeric(max_d-min_d)
  plot_d_std <- as.numeric(plot_d - min_d)*unit_x+min_d_std
  #####################################################################
  
  # design matrix for cases; only its column names are used, to label the 
  # etiology regression coefficients:
  Z_Eti       <- stats::model.matrix(model_options$likelihood$Eti_formula,
                                     data.frame(data_nplcm$X,Y=data_nplcm$Y)[data_nplcm$Y==1,,drop=FALSE])
  linpred      <- function(beta,design_matrix){design_matrix%*%beta}
  # etiology regression coefficients: (design column) x (cause) x (iteration); 
  # also useful in effect estimation.
  betaEti_samp <- array(t(get_res("^betaEti")),c(ncol_dm_Eti,Jcause,n_samp_kept),
                        dimnames = list(colnames(Z_Eti),
                                        model_options$likelihood$cause_list,
                                        1:n_samp_kept))
  # linear predictors for cases: (case) x (cause) x (iteration); 
  # can potentially just add pEti to the monitoring.
  out_Eti_linpred     <- array(apply(betaEti_samp,3,linpred,design_matrix=bugs.dat$Z_Eti),
                               c(Nd,Jcause,n_samp_kept))
  if (!is_nested){
    betaFPR_samp <- array(t(get_res(paste0("^betaFPR_",slice,"\\["))),c(ncol_dm_FPR,JBrS,n_samp_kept))
    thetaBS_samp <- get_res(paste0("^thetaBS_",slice,"\\["))
    out_FPR_linpred     <- array(apply(betaFPR_samp,3,linpred,design_matrix=bugs.dat[[paste0("Z_FPR_",slice)]]),
                                 c(Nd+Nu,JBrS,n_samp_kept))
  } else{
    ThetaBS_samp <- array(t(get_res(paste0("^ThetaBS_",slice,"\\["))),c(JBrS,K_curr,n_samp_kept))
    PsiBS_samp <- array(t(get_res(paste0("^PsiBS_",slice,"\\["))),c(JBrS,K_curr,n_samp_kept))
    Eta_samp <- array(t(get_res(paste0("^Eta_",slice,"\\["))),c(Nd,K_curr,n_samp_kept))
    Lambda_samp <- array(t(get_res(paste0("^Lambda_",slice,"\\["))),c(Nu,K_curr,n_samp_kept))
    # subclass weights, cases stacked on top of controls:
    subwt_samp <- abind::abind(Eta_samp,Lambda_samp,along=1)
    # etiology probabilities, (Nd+Nu) x Jcause x iteration; controls get zeros:
    pEti_samp           <- abind::abind(aperm(apply(out_Eti_linpred,c(1,3),softmax),c(2,1,3)),
                                        array(0,c(Nu,Jcause,n_samp_kept)),along=1)
    PR_case_ctrl <- compute_marg_PR_nested_reg_array(ThetaBS_array = ThetaBS_samp,PsiBS_array = PsiBS_samp,
                                                     pEti_mat_array = pEti_samp,subwt_mat_array = subwt_samp,
                                                     case = data_nplcm$Y,template = templateBS)
  }
  
  #
  # 2. use this code if date is included in etiology and false positive regressions:
  #
  
  # false positive rates (controls in the stratum, ordered by standardized date):
  subset_FPR_ctrl     <- data_nplcm$Y==0 & stratum_bool # <----- specifies who to look at. This may be hard to specify if unfamiliar with the data.
  plotid_FPR_ctrl     <- which(subset_FPR_ctrl)[order(data_nplcm$X$std_date[subset_FPR_ctrl])]
  curr_date_FPR       <- data_nplcm$X$std_date[plotid_FPR_ctrl]
  if(!is_nested){
    FPR_prob_scale      <- expit(out_FPR_linpred[plotid_FPR_ctrl,,])
  }else{
    FPR_prob_scale      <- PR_case_ctrl[plotid_FPR_ctrl,,]
  }
  FPR_mean <- apply(FPR_prob_scale,c(1,2),mean)
  FPR_q    <- apply(FPR_prob_scale,c(1,2),quantile,c(0.025,0.975))
  
  # ^ this could be changed theoretically to not use the observed data -> we could still make TPR/FPR plots with the posterior samples
  
  # positive rates for cases: mix the cause-specific response probabilities 
  # (true positive rate `theta` where `template` is 1, false positive rate `psi` 
  # where it is 0) over the etiology distribution `pEti_ord`.
  fitted_margin_case <- function(pEti_ord,theta,psi,template){
    mixture <-  pEti_ord
    tpr     <-  t(t(template)*theta)
    fpr     <- t(t(1-template)*psi)
    colSums(tpr*mixture + fpr*mixture)
  }
  
  # cases in the stratum, ordered by standardized date; the same subjects are 
  # used for the case positive rates and for the etiology curves:
  subset_Eti <- data_nplcm$Y==1 & stratum_bool # <--- specifies who to look at.
  plotid_Eti <- which(subset_Eti)[order(data_nplcm$X$std_date[subset_Eti])]
  curr_date_Eti  <- data_nplcm$X$std_date[plotid_Eti]
  plotid_FPR_case    <- plotid_Eti
  curr_date_FPR_case <- curr_date_Eti
  if (!is_nested){
    FPR_prob_scale_case      <- expit(out_FPR_linpred[plotid_FPR_case,,])
  }
  
  # etiology; Eti_prob_scale is (cause) x (case) x (iteration):
  if (!is_nested){
    Eti_prob_scale <- apply(out_Eti_linpred[plotid_Eti,,],c(1,3),softmax)
  }else{
    Eti_prob_scale <- aperm(pEti_samp,c(2,1,3))[,plotid_Eti,]
  }
  Eti_mean <- apply(Eti_prob_scale,c(1,2),mean)
  Eti_q    <- apply(Eti_prob_scale,c(1,2),quantile,c(0.025,0.975))
  # overall etiology: average over the selected cases within each iteration.
  Eti_overall <- apply(Eti_prob_scale,c(1,3),mean)
  Eti_overall_mean <- rowMeans(Eti_overall)
  Eti_overall_sd   <- apply(Eti_overall,1,sd)
  Eti_overall_q    <- apply(Eti_overall,1,quantile,c(0.025,0.975))
  
  if (!is_nested){
    PR_case <- array(NA,c(length(plotid_Eti),JBrS,n_samp_kept))
    # use the template of the chosen slice (read above by its exact name) rather 
    # than relying on partial matching of `bugs.dat$templateBS`:
    template_case <- templateBS[seq_len(Jcause),,drop=FALSE]
    for (i in seq_along(plotid_Eti)){
      for (s in seq_len(n_samp_kept)){
        PR_case[i,,s] <- fitted_margin_case(Eti_prob_scale[,i,s],
                                            thetaBS_samp[s,],
                                            FPR_prob_scale_case[i,,s],
                                            template_case)
      }
    }
  } else{
    PR_case <- PR_case_ctrl[plotid_Eti,,]
  }
  
  PR_case_mean <- apply(PR_case,c(1,2),mean)
  PR_case_q <- apply(PR_case,c(1,2),quantile,c(0.025,0.975))
  
  ##################
  # plot results:
  #################
  if (do_plot){  
    # restore every graphical parameter modified below when the function exits:
    old_par <- graphics::par(c("mfrow","mar","oma"))
    on.exit(graphics::par(old_par),add=TRUE)
    
    # raw moving average of y over a window of half-width `hw` in x:
    ma_cont <- function(y,x,hw=0.35){
      res <- rep(NA,length(y))
      for (i in seq_along(y)){
        res[i] <- mean(y[which(x>=x[i]-hw & x<=x[i]+hw)])
      }
      res
    }
    
    par(mfcol=c(2,Jcause),oma=c(3,0,3,0))
    for (j in seq_len(Jcause)){ # <--- the marginal dimension of measurements.
      # need to fix this for NoA! <------------------------ FIX!
      #
      # Figure 1 for case and control positive rates:
      #
      par(mar=c(2,5,0,1))
      #<------------------------ FIX!
      if (model_options$likelihood$cause_list[j] == "other"){
        plot(0,0.5,type="l",ylim=c(0,1),pch="n",
             xaxt="n",xlab="",ylab=c("","positive rate")[(j==1)+1],las=2,bty="n")
        
        mtext("other",side = 3,cex=1.5,line=1)
      } else if (is.na(match_cause(colnames(MBS_slice),model_options$likelihood$cause_list[j]))) {
        # no measurement in this slice for the j-th cause: empty panel with a title.
        plot(0,0.5,type="l",ylim=c(0,1),pch="n",
             xaxt="n",xlab="",ylab=c("","positive rate")[(j==1)+1],las=2,bty="n")
        mtext(model_options$likelihood$cause_list[j],side=3,cex=1.5,line=1)
      } else{                                  #<------------------------ FIX!
        plot(curr_date_FPR,FPR_mean[,j],type="l",ylim=c(0,1),
             xaxt="n",xlab="",ylab=c("","positive rate")[(j==1)+1],las=2,bty="n")
        polygon(c(curr_date_FPR, rev(curr_date_FPR)),
                c(FPR_q[1,,j], rev(FPR_q[2,,j])),
                col = grDevices::rgb(0, 1, 1,0.5),border = NA)
        
        # rug plot (controls): positives on top, negatives at the bottom.
        if(do_rug){
          rug(curr_date_FPR[MBS_slice[plotid_FPR_ctrl,j]==1],side=3,col="dodgerblue2",line=0)
          rug(curr_date_FPR[MBS_slice[plotid_FPR_ctrl,j]==0],side=1,col="dodgerblue2",line=1)
        }
        
        if(!is.null(truth$FPR)){lines(curr_date_FPR,truth$FPR[plotid_FPR_ctrl,j],col="blue",lwd=3)}
        if(!is.null(truth$TPR)){abline(h=truth$TPR[j],lwd=3,col="black")}
        if(plot_basis){matplot(curr_date_FPR,(bugs.dat[[paste0("Z_FPR_",slice)]])[plotid_FPR_ctrl,],col="blue",type="l",add=TRUE)}
        
        mtext(names(MBS_slice)[j],side = 3,cex=1.5,line=1)
        
        points(curr_date_FPR_case,PR_case_mean[,j],type="l",ylim=c(0,1))
        polygon(c(curr_date_FPR_case, rev(curr_date_FPR_case)),
                c(PR_case_q[1,,j], rev(PR_case_q[2,,j])),
                col =  grDevices::rgb(1, 0, 0,0.5),border = NA)
        if(!is.null(truth$PR_case)){lines(curr_date_FPR_case,truth$PR_case[plotid_FPR_case,j],col="black",lwd=3)}
        
        # rug plot (cases): positives on top, negatives at the bottom.
        if(do_rug){
          rug(curr_date_FPR_case[MBS_slice[plotid_FPR_case,j]==1],side=3,line= 1)
          rug(curr_date_FPR_case[MBS_slice[plotid_FPR_case,j]==0],side=1,line= 0)
          
          #labels for the rug plot
          if (j==1){
            mtext(text = "case   -->",side=2,at=line2user(1,3),cex=0.8,las=1)
            mtext(text = "case   -->",side=2,at=line2user(0,1),cex=0.8,las=1)
            mtext(text = "control-->",side=2,at=line2user(0,3), cex=0.8,las=1,col="dodgerblue2")
            mtext(text = "control-->",side=2,at=line2user(1,1), cex=0.8,las=1,col="dodgerblue2")
            
            mtext("1)",side=2,at=0.8,line=3, cex=2,las=1)
          }
        }
        
        if (!is_nested){
          # posterior mean and 95% interval of the j-th thetaBS column:
          abline(h=colMeans(thetaBS_samp)[j],col="red")
          abline(h=quantile(thetaBS_samp[,j],0.025),col="red",lty=2)
          abline(h=quantile(thetaBS_samp[,j],0.975),col="red",lty=2)
        }
        
        # add raw moving average dots; subjects with a missing measurement are 
        # dropped once, so responses and dates stay aligned:
        if (plot_ma_dots){
          response.ctrl <- (bugs.dat[[paste0("MBS_",slice)]])[plotid_FPR_ctrl,j]
          keep_ctrl     <- !is.na(response.ctrl)
          dat_ctrl      <- data.frame(std_date=curr_date_FPR[keep_ctrl])
          dat_ctrl$runmean <- ma_cont(response.ctrl[keep_ctrl],dat_ctrl$std_date)
          if (nrow(dat_ctrl)>0){
            points(runmean ~ std_date,data=dat_ctrl,lty=2,pch=1,cex=0.5,type="o",col="dodgerblue2")
          }
          
          response.case <- (bugs.dat[[paste0("MBS_",slice)]])[plotid_FPR_case,j]
          keep_case     <- !is.na(response.case)
          dat_case      <- data.frame(std_date=curr_date_FPR_case[keep_case])
          dat_case$runmean <- ma_cont(response.case[keep_case],dat_case$std_date)
          if (nrow(dat_case)>0){
            points(runmean ~ std_date,data=dat_case,lty=2,pch=1,cex=0.5,type="o")
          }
        }
      }
      #
      # Figure 2 for Etiology Regression:
      #
      par(mar=c(2,5,0,1))
      plot(curr_date_Eti,Eti_mean[j,],type="l",ylim=c(0,1),xlab="standardized date",
           ylab=c("","etiologic fraction")[(j==1)+1],bty="n",xaxt="n",yaxt="n",las=2)
      ## ONLY FOR SIMULATIONS <---------------------- FIX!
      if(!is.null(truth$Eti)){
        points(curr_date_Eti,truth$Eti[plotid_Eti,j],type="l",lwd=3,col="black")
        abline(h=colMeans(truth$Eti[data_nplcm$Y==1,])[j],col="blue",lwd=3)
      }
      if(plot_basis){matplot(curr_date_Eti,bugs.dat$Z_Eti[plotid_Eti,],col="blue",type="l",add=TRUE)}
      
      # overall pie:
      abline(h=Eti_overall_mean[j],col="black",lwd=2)
      abline(h=Eti_overall_q[,j],col="black",lty=2,lwd=1.5)
      
      mtext(paste0(round(Eti_overall_mean[j],3)*100,"%"),side=3,line=-2,cex=1.2)
      mtext(paste0(round(Eti_overall_q[1,j],3)*100,"%"),side=3,line=-3,cex=1,adj=0.15)
      mtext(paste0(round(Eti_overall_q[2,j],3)*100,"%"),side=3,line=-3,cex=1,adj=0.85)
      
      if (j==2){
        mtext("<- Overall Pie ->",side=2,at=(line2user(-2,3)+line2user(-1,3))/2,las=1,cex=0.8,col="blue")
        mtext("<- 95% CrI ->",side=2,at=(line2user(-3,3)+line2user(-2,3))/2,las=1,cex=0.8,col="blue")
      }
      
      if (j==1){mtext("2)",side=2,at=0.85,line=3, cex=2,las=1)}
      
      # date axis: quarterly ticks; the year is printed only on the first tick 
      # of every calendar year.
      rle_res <- rle(year(plot_d))
      format_seq <- rep("%b-%d",length(plot_d))
      format_seq[cumsum(c(1,rle_res$lengths[-length(rle_res$lengths)]))] <- "%Y:%b-%d"
      
      axis(1, plot_d_std,
           format(c(plot_d), 
                  format_seq),
           cex.axis = 0.8,las=2)
      
      axis(2,at = seq(0,1,by=0.2),labels=seq(0,1,by=0.2),las=2)
      
      rug(X$std_date[c(plotid_FPR_case)],side=1,line=-0.2,cex=1)
      
      if (j==1){
        mtext(text = "case   -->",side=2,at=line2user(-0.2,1),cex=0.8,las=1)
      }
      
      polygon(c(curr_date_Eti, rev(curr_date_Eti)),
              c(Eti_q[1,j,], rev(Eti_q[2,j,])),
              col = grDevices::rgb(0.5,0.5,0.5,0.5),border = NA)
    }
  }
  
  if (!return_metric){return(invisible(NULL))}
  
  # summary table of the overall etiology, one row per cause:
  res <- t(do.call("rbind",make_list(Eti_overall_mean,Eti_overall_sd,Eti_overall_q)))
  rownames(res) <- model_options$likelihood$cause_list
  colnames(res) <- c("post.mean","post.sd","CrI_025","CrI_0975")
  
  if (!is.null(truth$Eti)){
    # simulation metrics against the supplied truth:
    Eti_overall_truth  <- colMeans(truth$Eti[plotid_Eti,])
    # compute integrated squared error of the posterior mean curve, per cause:
    Eti_ISE <- apply(Eti_mean-t(truth$Eti[plotid_Eti,]),1,function(v) sum(v^2)/length(v))
    # does the 95% credible interval of the overall etiology cover the truth?
    Eti_overall_cover  <- vapply(seq_along(Eti_overall_mean),
                                 function(s) (Eti_overall_truth[s]<= Eti_overall_q[2,s]) && 
                                   (Eti_overall_truth[s]>= Eti_overall_q[1,s]),
                                 logical(1))
    Eti_overall_bias  <- Eti_overall_mean -  Eti_overall_truth
    
    return(make_list(Eti_overall_mean,Eti_overall_q,Eti_overall_sd,
                     Eti_overall_cover, Eti_overall_bias,Eti_overall_truth,Eti_ISE,res,parsed_model))
  }
  
  # regression coefficients contrasted against the last cause:
  tt_minus <- sweep(betaEti_samp,c(1,3),betaEti_samp[,Jcause,],"-")
  betaEti_mean  <- apply(tt_minus,c(1,2),mean)
  betaEti_sd  <- apply(tt_minus,c(1,2),sd)
  betaEti_q1  <- apply(tt_minus,c(1,2),quantile,0.025)
  betaEti_q2  <- apply(tt_minus,c(1,2),quantile,0.975)
  beta_res <- make_list(betaEti_mean,betaEti_sd,betaEti_q1,betaEti_q2)
  names(beta_res) <- c("post.mean","post.sd","CrI_025","CrI_0975")
  
  make_list(Eti_overall_mean,Eti_overall_q,Eti_overall_sd,res,beta_res,parsed_model)
} 


#' visualize the subclass weight regression with a continuous covariate
#' 
#' For each subclass of the chosen bronze-standard slice, the top row shows all 
#' posterior draws of the subclass weight curve against the standardized date 
#' (`std_date`); the bottom row shows the posterior mean and the 2.5% and 97.5% 
#' posterior quantiles. If true subclass weights are supplied, every estimated 
#' subclass is overlaid with the true curve that has the highest average 
#' correlation with its posterior draws.
#' 
#' @param DIR_NPLCM File path to the folder containing posterior samples
#' @param stratum_bool a vector of TRUE/FALSE with TRUE indicating the rows of subjects to include
#' @param case 1 for plotting cases, 0 for plotting controls; default to 0.
#' @param slice integer; specifies which slice of bronze-standard data to visualize; Default to 1.
#' @param truth a list of truths computed from true parameters in simulations; 
#'  default to `NULL` in real data analyses. Elements used by this function:
#'  \itemize{
#'      \item truth_subwt matrix of # of rows = # of subjects, # columns: number of true subclasses
#'      \item ord_subclass (optional) a vector of subclass indices giving the 
#'      order in which the estimated subclasses are displayed; default is `1:K`.
#' }
#' @param RES_NPLCM pre-read res_nplcm; default to NULL.
#' @return A figure of subclass regression curves; the function returns `NULL` invisibly.
#' @family visualization functions
#'    
plot_subwt_regression <- function(DIR_NPLCM,stratum_bool,case=0,slice=1,truth=NULL,RES_NPLCM=NULL){
  if (!is_jags_folder(DIR_NPLCM)){
    stop("==[baker] Oops, not a folder baker recognizes. Try a folder generated by baker, e.g., a temporary folder?==")
  }
  # JAGS:
  #
  # Read data from DIR_NPLCM:
  #
  data_nplcm <- dget(file.path(DIR_NPLCM,"data_nplcm.txt"))  
  model_options <- dget(file.path(DIR_NPLCM,"model_options.txt"))
  
  if(is.null(data_nplcm$X$std_date)){
    stop("'std_date' is not a variable in your dataset! Make sure that the continuous covariate exists and retry this function.")
  }
  cat("==[baker] plotting >>",c("case","control")[2-case],"<< subclass weight regression with >> nested << model for BrS Measure slice = ",slice,": ",names(data_nplcm$Mobs$MBS)[[slice]]," .==\n")
  
  # the data passed to JAGS are stored as R assignments; evaluate them in a 
  # scratch environment and keep them as a list:
  new_env <- new.env()
  source(file.path(DIR_NPLCM,"jagsdata.txt"),local=new_env)
  bugs.dat <- as.list(new_env)
  rm(new_env)
  if (!is.null(RES_NPLCM)){
    res_nplcm <- RES_NPLCM
  } else {
    res_nplcm <- coda::read.coda(file.path(DIR_NPLCM,"CODAchain1.txt"),
                                 file.path(DIR_NPLCM,"CODAindex.txt"),
                                 quiet=TRUE)
  }
  # pick the posterior sample columns whose names match the regular expression `x`:
  get_res   <- function(x) res_nplcm[,grep(x,colnames(res_nplcm))]
  
  # Without monitored subclass weights the arrays below would silently be 
  # filled with NA and the figure would be empty, so fail early instead:
  eta_pattern    <- paste0("^Eta_",slice,"\\[")
  lambda_pattern <- paste0("^Lambda_",slice,"\\[")
  if (!any(grepl(eta_pattern,colnames(res_nplcm))) || 
      !any(grepl(lambda_pattern,colnames(res_nplcm)))){
    stop("==[baker] Oops, no posterior samples of subclass weights ('Eta'/'Lambda') found for this slice. Subclass weight regression needs a nested model.==")
  }
  
  # structure the posterior samples:
  n_samp_kept   <- nrow(res_nplcm)
  Nd            <- bugs.dat$Nd
  Nu            <- bugs.dat$Nu
  K_curr        <- model_options$likelihood$k_subclass[slice]
  
  Eta_samp <- array(t(get_res(eta_pattern)),c(Nd,K_curr,n_samp_kept))
  Lambda_samp <- array(t(get_res(lambda_pattern)),c(Nu,K_curr,n_samp_kept))
  # subject x subclass x iteration; Eta (Nd rows) stacked on top of Lambda (Nu rows):
  subwt_samp <- abind::abind(Eta_samp,Lambda_samp,along=1)
  
  # subjects to display (cases or controls within the stratum), ordered by 
  # standardized date:
  subset_FPR_ctrl     <- data_nplcm$Y==case & stratum_bool # <--- specifies who to look at.
  plotid_FPR_ctrl     <- which(subset_FPR_ctrl)[order(data_nplcm$X$std_date[subset_FPR_ctrl])]
  plot_date           <- data_nplcm$X$std_date[plotid_FPR_ctrl]
  #
  # LCM plotting subclass weight curves:
  #
  k_seq <- 1:K_curr # <----------------- adjust order of k via truth$ord_subclass.
  if(!is.null(truth$ord_subclass)){k_seq <- truth$ord_subclass}
  
  has_truth <- !is.null(truth$truth_subwt)
  if (has_truth){
    truth_subwt <- truth$truth_subwt
    K_truth     <- ncol(truth_subwt)
    # pad with zero-weight subclasses if the model is fitted with more 
    # subclasses than the truth:
    if (K_curr > K_truth){
      truth_subwt <- cbind(truth_subwt,matrix(0,nrow=nrow(truth_subwt), ncol=K_curr - K_truth))
    }
    
    ## match the truth subclass weights to the estimated ones:
    # for the estimated subclass displayed at position `class`, store the 
    # column of `truth_subwt` (among those in `k_seq`) whose curve has the 
    # highest correlation with the posterior draws, averaged over iterations.
    class_to_true_class <- rep(NA_integer_, length(k_seq))
    for (class in seq_along(k_seq)) {
      avg_cor <- vapply(seq_along(k_seq), function(true_class){
        mean(stats::cor(subwt_samp[plotid_FPR_ctrl, k_seq[class], ],
                        truth_subwt[plotid_FPR_ctrl, k_seq[true_class]]))
      }, numeric(1))
      best <- which.max(avg_cor) # empty if no correlation could be computed.
      if (length(best) > 0){class_to_true_class[class] <- k_seq[best]}
    }
  }
  
  # overlay the matched true subclass weight curve, if any, on the current panel:
  add_truth <- function(k){
    if (!has_truth){return(invisible(NULL))}
    true_k <- class_to_true_class[k]
    if (is.na(true_k)){return(invisible(NULL))}
    matplot(plot_date, truth_subwt[plotid_FPR_ctrl, true_k],
            type="l",add=TRUE,lwd=4,col="black",lty=1)
  }
  
  # restore the graphical parameters when the function exits:
  old_par <- graphics::par(c("mfrow", "mar"))
  on.exit(graphics::par(old_par),add=TRUE)
  
  par(mfrow=c(2,K_curr))
  # top row: all posterior draws of the subclass weight curves.
  for (k in seq_along(k_seq)){
    matplot(plot_date,subwt_samp[plotid_FPR_ctrl,k_seq[k],],
            col=2,type="l",ylim=c(0,1),main=k,xlab="scaled date",ylab="subclass weight")
    add_truth(k)
  }
  
  # bottom row: posterior mean (black) and 95% pointwise intervals (blue).
  for (k in seq_along(k_seq)){
    matplot(plot_date,t(apply(subwt_samp[plotid_FPR_ctrl,k_seq[k],],1,quantile, c(0.025,0.975))),
            col="blue",
            type="l",ylim=c(0,1),main=k,lty=2,xlab="scaled date",ylab="subclass weight")
    points(plot_date,apply(subwt_samp[plotid_FPR_ctrl,k_seq[k],],1,mean),col="black",lty=2,
           type="l")
    add_truth(k)
  }
  invisible(NULL)
}


#' visualize the etiology estimates for each level of a discrete covariate
#' 
#' This function visualizes the etiology estimates against one discrete covariate, e.g., 
#' age groups. 
#' 
#' @param DIR_NPLCM File path to the folder containing posterior samples
#' @param strata_weights a vector of weights that sum to one; for each pathogen
#' the weights specify how the j-th etiology fraction should be combined across all
#' levels of the discrete predictors in the data; default is `"empirical"`
#' to use empirical weights (observed fractions of subjects across strata).
#' @param truth a list of true values, e.g., 
#' `truth=list(allEti = <a list of etiology fractions, each of identical length - the # of strata; >)`;
#' if available, will be shown in thicker red solid vertical lines.
#' @param RES_NPLCM pre-read `res_nplcm`; default to `NULL`.
#' @param show_levels a vector of integers less than or equal to the total number of 
#' levels of strata; default to `0` for overall.
#' @param is_plot default to TRUE, plotting the figures; if `FALSE` only returning summaries
#' @import graphics ggplot2
#' @importFrom ggpubr ggarrange
#' @importFrom stats aggregate
#' @importFrom reshape2 melt
#' @family visualization functions
#'    
#' @return plotting function
plot_etiology_strat <- function(DIR_NPLCM,strata_weights = "empirical",
                                truth=NULL,
                                RES_NPLCM=NULL,show_levels=0,is_plot=TRUE){
  # ### test
  # DIR_NPLCM = result_folder_discrete
  # strata_weights = "empirical"
  # truth=list(allEti = etiology_allsites)
  # RES_NPLCM = NULL
  # show_levels=0
  # ### test....
  
  old_par <- graphics::par(graphics::par("mfrow", "mar"))
  on.exit(graphics::par(old_par))
  if (!is_jags_folder(DIR_NPLCM)){
    stop("==[baker] Oops, not a folder that baker recognizes. Try a folder generated by baker, e.g., a temporary folder?==")
  }
  # JAGS:
  #
  # Read data from DIR_NPLCM:
  #
  data_nplcm <- dget(file.path(DIR_NPLCM,"data_nplcm.txt"))  
  model_options <- dget(file.path(DIR_NPLCM,"model_options.txt"))
  mcmc_options <- dget(file.path(DIR_NPLCM,"mcmc_options.txt"))
  parsed_model <- assign_model(model_options,data_nplcm)
  is_nested    <- parsed_model$nested
  if(is_plot){cat("==[baker] plotting stratified etiologies with >>",c("nested", "non-nested")[2-is_nested],"<< model==\n")}
  new_env <- new.env()
  source(file.path(DIR_NPLCM,"jagsdata.txt"),local=new_env)
  bugs.dat <- as.list(new_env)
  rm(new_env)
  if (!is.null(RES_NPLCM)){res_nplcm <- RES_NPLCM
  } else {res_nplcm <- coda::read.coda(file.path(DIR_NPLCM,"CODAchain1.txt"),
                                       file.path(DIR_NPLCM,"CODAindex.txt"),
                                       quiet=TRUE)}
  print_res <- function(x) plot(res_nplcm[,grep(x,colnames(res_nplcm))])
  get_res   <- function(x) res_nplcm[,grep(x,colnames(res_nplcm))]
  
  # structure the posterior samples:
  n_samp_kept   <- nrow(res_nplcm)
  if (is_nested){ncol_dm_Eti   <- ncol(bugs.dat$Z_Eti)}
  Jcause        <- bugs.dat$Jcause
  Nd            <- bugs.dat$Nd
  Nu            <- bugs.dat$Nu
  
  likelihood <- model_options$likelihood
  Y    <- data_nplcm$Y
  X    <- data_nplcm$X
  
  # generate design matrix for etiology regression:
  Eti_formula <- likelihood$Eti_formula
  
  is_discrete_Eti     <- is_discrete(data.frame(X,Y)[Y==1,,drop=FALSE], Eti_formula)
  if (!is_discrete_Eti){stop("==[baker] fitted model does not have all discrete covariates for etiology.==")}
  
  if (!is_nested){
    unique_Eti_level <- dget(file.path(DIR_NPLCM,"unique_Eti_level.txt"))  
    n_unique_Eti_level <- bugs.dat$n_unique_Eti_level  # number of stratums
    
    # etiology:
    plotid_Eti <- which(data_nplcm$Y==1) # <--- specifies who to look at.
    Eti_prob_scale <- array(get_res("pEti"),c(n_samp_kept,n_unique_Eti_level,Jcause))
    
    # posterior etiology mean for each cause for each site
    Eti_mean <- apply(Eti_prob_scale,c(2,3),mean)   
    # posterior etiology quantiles for each cause for each site
    Eti_q    <- apply(Eti_prob_scale,c(2,3),quantile,c(0.025,0.975))
    
    # marginalized posteior etiology ignoring site
    Eti_overall <- apply(Eti_prob_scale,c(3,1),mean)
    # posteior etiology mean for each cause across all sites
    Eti_overall_mean <- rowMeans(Eti_overall)
    # posteior etiology quantiles for each cause across all sites
    Eti_overall_q    <- apply(Eti_overall,1,quantile,c(0.025,0.975))
  }
  
  if(is_nested){
    # design matrix for etiology regression (because this function is 
    # for all discrete predictors, the usual model.matrix is sufficient):
    Z_Eti       <- stats::model.matrix(Eti_formula,data.frame(X,Y)[Y==1,,drop=FALSE])
    # Z_Eti0       <- stats::model.matrix(Eti_formula,data.frame(X,Y)[Y==1,,drop=FALSE])
    # a.eig        <- eigen(t(Z_Eti0)%*%Z_Eti0)
    # sqrt_Z_Eti0  <- a.eig$vectors %*% diag(sqrt(a.eig$values)) %*% solve(a.eig$vectors)
    # 
    # Z_Eti  <- Z_Eti0%*%solve(sqrt_Z_Eti0)
    
    ncol_dm_Eti        <- ncol(Z_Eti)
    Eti_colname_design_mat <- attributes(Z_Eti)$dimnames[[2]]
    attributes(Z_Eti)[names(attributes(Z_Eti))!="dim"] <- NULL 
    # this to prevent issues when JAGS reads in data.
    
    unique_Eti_level   <- unique(Z_Eti)
    n_unique_Eti_level <- nrow(unique_Eti_level)
    Eti_stratum_id     <- apply(Z_Eti,1,function(v) 
      which(rowSums(abs(unique_Eti_level-t(replicate(n_unique_Eti_level,v))))==0))
    rownames(unique_Eti_level) <- 1:n_unique_Eti_level
    colnames(unique_Eti_level) <- Eti_colname_design_mat
    
    #stratum names for etiology regression:
    #dput(unique_Eti_level,file.path(mcmc_options$result.folder,"unique_Eti_level.txt"))
    
    betaEti_samp <- array(t(get_res("^betaEti")),c(ncol_dm_Eti,Jcause,n_samp_kept))
    linpred      <- function(beta,design_matrix){design_matrix%*%beta}
    
    # out_caseFPR_linpred     <- array(apply(case_betaFPR_samp,3,linpred,design_matrix=bugs.dat[[paste0("Z_FPR_",slice)]]),
    #                              c(Nd+Nu,K_curr,n_samp_kept))
    out_Eti_linpred     <- array(apply(betaEti_samp,3,linpred,design_matrix=unique_Eti_level),
                                 c( nrow(unique_Eti_level),Jcause,n_samp_kept)) # can potentially just add pEti to the monitoring.
    #pEti_samp           <- apply(out_Eti_linpred,c(1,3),softmax) # Jcause by Nd by niter.
    pEti_samp           <- aperm(apply(out_Eti_linpred,c(1,3),softmax),c(2,1,3))
    
    # etiology:
    Eti_prob_scale <- pEti_samp
    Eti_mean <- apply(Eti_prob_scale,c(1,2),mean)
    Eti_q    <- apply(Eti_prob_scale,c(1,2),quantile,c(0.025,0.975))
    
    # Eti_overall <- apply(Eti_prob_scale,c(1,3),mean)
    # Eti_overall_mean <- rowMeans(Eti_overall)
    # Eti_overall_sd   <- apply(Eti_overall,1,sd)
    # Eti_overall_q    <- apply(Eti_overall,1,quantile,c(0.025,0.975))
    
    Eti_prob_scale <- aperm(Eti_prob_scale,c(3,1,2))
  }
  
  # weight to marginalize posterior etiology distributions across strata
  user_weight <- rep(1/n_unique_Eti_level,n_unique_Eti_level) # c(0.3,0.2,0.1,0.1,0.1,0.1,0.1)
  
  if (!is.null(strata_weights) && strata_weights =="empirical"){
    strat_ind <- match_cause(apply(unique_Eti_level,1,paste,collapse=""),
                             apply(stats::model.matrix(Eti_formula,data_nplcm$X[data_nplcm$Y==1,,drop=FALSE]),1,paste,collapse=""))
    empirical_wt <- rep(NA,nrow(unique_Eti_level))
    for (s in seq_along(empirical_wt)){
      empirical_wt[s] <- mean(strat_ind==s)
    }
    user_weight <- empirical_wt
  } else{
    if(!is.null(strata_weights) && length(strata_weights==n_unique_Eti_level)){
      user_weight <- strata_weights
    }
  }
  
  # marginalized posterior etiology over all sites using user-defined weights
  Eti_overall_usr_weight <- apply(Eti_prob_scale,1,function(S) t(S)%*%matrix(user_weight,ncol=1))
  # marginalized posterior etiology mean using user-defined weights
  Eti_overall_mean_usr_weight <- rowMeans(Eti_overall_usr_weight)
  # marginalized posterior etiology quantiles using user-defined weights
  Eti_overall_q_usr_weight    <- apply(Eti_overall_usr_weight,1,quantile,c(0.025,0.975))
  
  #
  # start plotting:
  #
  plot_list <- list()
  res_list <- list()
  for(site in 1:n_unique_Eti_level){
    
    # shape probabilities array into dataset for ggplot 
    etiData = data.frame(Eti_prob_scale[,site,,drop=FALSE])
    names(etiData) = model_options$likelihood$cause_list
    plotData = melt(etiData,id.vars = NULL)
    names(plotData) = c("cause", "prob")
    
    # compute posterior means and interval end points:
    summaryData <-cbind( aggregate(prob~cause,data=plotData,FUN=mean),
                         aggregate(prob~cause,data=plotData,FUN=quantile,c(0.025,0.975))[,-1])
    
    colnames(summaryData)[c(2,3,4)] <- c("eti_mean","ci_025","ci_975")
    res_list[[site]] <- summaryData
    names(res_list)[site] <- paste0("Posterior distributions of CSCFs for stratum: ", site,"; weight: ",round(user_weight[site],4))
    
    if (is_plot){ 
      ## plot histograms of the posterior probabilities for each stratum 
      plot_list[[site]] <- ggplot(plotData, aes(x=prob)) +
        geom_histogram(fill="#ffc04d",binwidth=0.01) + 
        geom_vline(data=summaryData, 
                   mapping = aes(xintercept=eti_mean), colour="#005b96") +
        geom_vline(data=summaryData, 
                   mapping = aes(xintercept=ci_025), colour="green", linetype="dashed") +
        geom_vline(data=summaryData, 
                   mapping = aes(xintercept=ci_975), colour="green", linetype="dashed") +
        facet_wrap(~cause,ncol=nrow(summaryData)) +
        labs(title=paste0("Posterior distributions of CSCFs for stratum: ", site,"; weight: ",round(user_weight[site],4)),
             subtitle = "Posterior mean displayed as solid line \n 95% CrIs displayed as dashed lines",
             x = "CSCF", y ="Frequency")
      
      
      if (!is.null(truth$allEti)){# add truth if available.
        summaryData$truth <- truth$allEti[[site]]
        plot_list[[site]] <-plot_list[[site]] +
          geom_vline(data=summaryData, 
                     mapping = aes(xintercept=truth), colour="red", linetype="solid",lwd=1)
      }
    }
    
  }
  # shape the overall marginal etiology probabilities as a data frame 
  etiData = data.frame(t(Eti_overall_usr_weight))
  names(etiData) = model_options$likelihood$cause_list
  plotData = melt(etiData,id.vars = NULL)
  names(plotData) = c("cause", "prob")
  
  summaryData <-cbind( aggregate(prob~cause,data=plotData,FUN=mean),
                       aggregate(prob~cause,data=plotData,FUN=quantile,c(0.025,0.975))[,-1])
  
  colnames(summaryData)[c(2,3,4)] <- c("eti_mean","ci_025","ci_975")
  
  res_list[[n_unique_Eti_level+1]] <- summaryData
  names(res_list)[n_unique_Eti_level+1] <-paste0("Posterior distributions of CSCFs (across all levels using weights: (",paste(round(user_weight,3),collapse=","),"))")
  
  if (is_plot){
    ## plot histograms of the posterior probabilities for overall etiology 
    plot_list[[n_unique_Eti_level+1]] <- ggplot(plotData, aes(x=prob)) +
      geom_histogram(fill="#ffc04d",binwidth=0.01) + 
      geom_vline(data=summaryData, 
                 mapping = aes(xintercept=eti_mean), colour="#005b96") +
      geom_vline(data=summaryData, 
                 mapping = aes(xintercept=ci_025), colour="green", linetype="dashed") +
      geom_vline(data=summaryData, 
                 mapping = aes(xintercept=ci_975), colour="green", linetype="dashed") +
      facet_wrap(~cause,ncol=nrow(summaryData)) +
      labs(title=paste0("Posterior distributions of CSCFs (across all levels using weights: (",paste(round(user_weight,3),collapse=","),"))"),
           subtitle = "Posterior mean displayed as solid line \n 95% CrIs displayed as dashed lines",
           x = "CSCF", y ="Frequency")
    
    
    if (!is.null(truth$allEti)){ # add truth if available.
      summaryData$truth <- c(matrix(user_weight,nrow=1)%*%do.call(rbind,truth$allEti))
      plot_list[[n_unique_Eti_level+1]] <- plot_list[[n_unique_Eti_level+1]] +
        geom_vline(data=summaryData, 
                   mapping = aes(xintercept=truth), colour="red",  linetype="solid",lwd=1)
    }
    

    
    if (sum(show_levels>n_unique_Eti_level)>0){stop("==[baker] check 'unique_Eti_level'; `show_levels`
                                              cannot be larger than its number of rows.")}
    if (length(show_levels)==1){
      if (show_levels ==0){plot_list_show <-  plot_list[n_unique_Eti_level+1]
      } else{
        plot_list_show <-  plot_list[show_levels]
      }
    }else{
      if (0 %in% show_levels){
        plot_list_show <- c(plot_list[show_levels[show_levels!=0]],plot_list[n_unique_Eti_level+1])
      }else{
        plot_list_show <- plot_list[show_levels]
      }
    }
    
    print("==[baker] actual meanings of levels (by row):")
    print(unique_Eti_level)
    print(ggpubr::ggarrange(plotlist=plot_list_show,nrow = length(plot_list_show)))
  }

  if (!is_plot){return(make_list(res_list,parsed_model,unique_Eti_level))}
  
  # ggpubr::ggarrange(plotlist=plot_list_show,nrow = length(plot_list_show))
  
  
  # # plot posterior distribution for etiology probability
  # par(mfcol=c(1+n_unique_Eti_level,Jcause),mar=c(3,8,1,0))
  # for (j in 1:Jcause){
  #   # if (j==1) {par(mar=c(3,0,2,0))}
  #   # if (j>1)  {par(mar=c(3,0,1,0))}
  #   for (site in 1:n_unique_Eti_level){
  #     hist(Eti_prob_scale[,site,j],xlim=c(0,1),breaks="Scott",freq=FALSE,main="",xlab="",
  #          ylim=c(0,20))
  #     if (!is.null(truth$allEti)){
  #       abline(v = truth$allEti[[site]][[j]], col="blue", lwd=3, lty=2) # mark the truth.
  #       q_interval <- quantile(Eti_prob_scale[,site,j],c(0.025,0.975))
  #       is_included <- truth$allEti[[site]][[j]] < q_interval[2] && truth$allEti[[site]][[j]] > q_interval[1]
  #       abline(v = q_interval,col=c("gray","red")[2-is_included],lwd=2,lty=1)
  #     }
  # 
  #     if (site==1){mtext(text = likelihood$cause_list[j],3,-1.2,cex=2,adj = 0.9)}
  #     if (j==1){
  #       Lines <- list(bquote(paste("level ",.(levels(as.factor(data_nplcm$X$SITE))[site]))),
  #                     bquote(paste("",.(round(user_weight[site],4)))))
  #       mtext(do.call(expression, Lines),side=2,line=c(5.5,3.75),cex=1.5,col="blue")
  # 
  #       #mtext(paste0(round(user_weight[site],4)),2,5,cex=2,col="blue",las=1)
  #       #if (site==ceiling(n_unique_Eti_level/2)) {mtext("User-specified weight towards overall pie:", 2,12, cex=3)}
  #     }
  #   }
  #   hist(Eti_overall_usr_weight[j,],xlim=c(0,1),breaks="Scott",freq=FALSE,main="",#col="blue",
  #        xlab="Etiology",ylim=c(0,20))
  #   if (j==1){
  #     if (!is.null(strata_weights) && strata_weights=="empirical"){
  #       #mtext("overall pie: empirical weights", 2,5, cex=1)
  # 
  #       Lines <- list(bquote(paste("overall pie (", pi[l],"*)")),
  #                     "empirical weights")
  #       mtext(do.call(expression, Lines),side=2,line=c(5.5,3.5),cex=1.5)
  # 
  #       #mtext(side=2, line=c(5,4), expression(paste("overall pie:",pi[l]," \n empirical weights")))
  #     } else{
  #       mtext("overall pie: ", 2,5, cex=1)
  #     }
  #   }
  #   if (!is.null(truth$allEti)){
  #     truth_overall <- c(matrix(user_weight,nrow=1)%*%do.call(rbind,truth$allEti))
  #     abline(v=truth_overall[j],col="blue",lty=1,lwd=2) # mark the truth.
  #     q_interval_Eti_overall <- quantile(Eti_overall_usr_weight[j,],c(0.025,0.975))
  #     is_included_Eti_overall <- truth_overall[j] < q_interval_Eti_overall[2] &&
  #       truth_overall[j] > q_interval_Eti_overall[1]
  #     abline(v = q_interval_Eti_overall,col=c("gray","red")[2-is_included_Eti_overall],lwd=2,lty=1)
  # 
  #   }
  #   #mtext(text = model_options$likelihood$cause_list[j],3,adj=0.9,cex=2,col="blue")
  # }
}

# truth0     <-list(allEti= as.data.frame(t(compute_pEti(data.frame(siteID=c(1,2)),betaEti0))))
# plot_etiology_strat_nested(DIR_NPLCM,strata_weights = "empirical",
#                            truth=truth0,RES_NPLCM = RES_NPLCM_curr)

#' visualize the PERCH etiology regression with a continuous covariate
#' 
#' This function is specifically designed for PERCH data: it plots against enrollment
#' date and therefore requires `ENRLDATE` and `std_date` in `data_nplcm$X`.
#'  (NB: dealing with NoA, multiple-pathogen causes, other continuous covariates?
#' also there this function only plots the first slice - so generalization may be useful - give
#' users an option to choose slice s; currently default to the first slice.)
#' 
#' @param DIR_NPLCM File path to the folder containing posterior samples
#' @param stratum_bool a vector of TRUE/FALSE, one entry per subject, with TRUE 
#' indicating the subjects (rows) to include in the plots and summaries. Must be supplied.
#' @param bugs.dat The data list passed to JAGS, as stored in `jagsdata.txt` under 
#' `DIR_NPLCM`; supply a pre-loaded copy to save time. Default is `NULL`, in which case 
#' the file is read.
#' @param slice integer; specifies which slice of bronze-standard data to visualize; Default to 1.
#' @param RES_NPLCM pre-read res_nplcm; default to NULL.
#' @param do_plot TRUE for plotting
#' @param do_rug TRUE for adding rug marks of the controls' observed measurements to 
#' the positive-rate panels
#' @param return_metric TRUE for returning the overall mean etiology, its 95% interval
#'  end points and standard deviation.
#' @return A figure of etiology regression curves and some marginal positive rate assessment of
#' model fit; See example for the legends. If `return_metric=TRUE`, the list built by
#' `make_list(Eti_overall_mean,Eti_overall_q,Eti_overall_sd)`.
plot_case_study <- function(
  DIR_NPLCM, stratum_bool=stratum_bool, bugs.dat=NULL, slice=1, RES_NPLCM=NULL, do_plot=TRUE, do_rug=FALSE, return_metric=TRUE){
  # only for testing; remove after testing:
  # DIR_NPLCM <- result_folder
  # stratum_bool <- DISCRETE_BOOL
  # discrete_X_names <- c("AGE","ALL_VS") # must be the discrete variables used in Eti_formula.
  # <------------------------------- end of testing.
  
  # The default in the signature refers to itself, so an omitted argument would
  # otherwise only fail later with "promise already under evaluation".
  if (missing(stratum_bool)){
    stop("==[baker] `stratum_bool` must be supplied: a TRUE/FALSE vector with one entry per subject.==")
  }
  if (do_plot){
    # `oma` is changed below in addition to the layout and margins; restore all three.
    old_par <- graphics::par("mfrow", "mar", "oma")
    on.exit(graphics::par(old_par), add=TRUE)
  }
  if (!is_jags_folder(DIR_NPLCM)){
    stop("==[baker] Oops, not a folder baker recognizes. Try a folder generated by baker, e.g., a temporary folder?==")
  }
  # JAGS:
  #
  # Read data from DIR_NPLCM:
  #
  data_nplcm <- dget(file.path(DIR_NPLCM,"data_nplcm.txt"))  
  model_options <- dget(file.path(DIR_NPLCM,"model_options.txt"))
  mcmc_options <- dget(file.path(DIR_NPLCM,"mcmc_options.txt"))
  # `k_subclass` is indexed by slice below, so it may hold several values;
  # `any()` keeps the condition scalar. Only used for the status message.
  is_nested <- any(model_options$likelihood$k_subclass>1)
  
  K_curr        <- model_options$likelihood$k_subclass[slice]
  Jcause        <- length(model_options$likelihood$cause_list)
  Nd            <- sum(data_nplcm$Y==1)
  Nu            <- sum(data_nplcm$Y==0)
  
  
  # structure the posterior samples:
  if(is.null(bugs.dat)){
    new_env <- new.env()
    source(file.path(DIR_NPLCM,"jagsdata.txt"),local=new_env)
    bugs.dat <- as.list(new_env)
    rm(new_env)
  } 
  
  # bronze-standard measurements of the chosen slice:
  MBS_curr    <- data_nplcm$Mobs$MBS[[slice]]
  # number of measurements in this slice (was hard-coded to 7):
  JBrS        <- ncol(MBS_curr)
  
  ncol_dm_Eti   <- ncol(bugs.dat$Z_Eti)
  
  templateBS    <- bugs.dat[[paste0("templateBS_",slice)]]
  
  
  if (do_plot){
    cat("==[baker] plotting etiology regression with >>",c("nested", "non-nested")[2-is_nested],"<< model for BrS Measure slice = ",slice,": ",names(data_nplcm$Mobs$MBS)[[slice]]," .==\n")
  }
  
  if (!is.null(RES_NPLCM)){
    res_nplcm <- RES_NPLCM
  } else {
    res_nplcm <- coda::read.coda(file.path(DIR_NPLCM,"CODAchain1.txt"),
                                 file.path(DIR_NPLCM,"CODAindex.txt"),
                                 quiet=TRUE)
  }
  n_samp_kept   <- nrow(res_nplcm)
  # pull the posterior draws whose parameter names match the pattern `x`:
  get_res   <- function(x) res_nplcm[,grep(x,colnames(res_nplcm))]
  
  
  #####################################################################
  ## we could require that people have ENRL
  # x-axis for dates: map calendar dates onto the standardized date scale.
  X <- data_nplcm$X
  
  dd <-  as.Date(X$ENRLDATE)
  min_d <- min(dd)
  min_d_std <- unique(X$std_date[which(dd==min_d)])
  # first day of the month following the earliest enrollment; passing the date
  # (not just its month number) lets leap-year Februaries be counted correctly.
  min_plot_d <- min_d+as.numeric(lubridate::days_in_month(min_d))-lubridate::day(min_d)+1
  
  max_d <- max(dd)
  max_d_std <- unique(X$std_date[which(dd==max_d)])
  # first day of the month of the latest enrollment:
  max_plot_d <- max_d-lubridate::day(max_d)+1
  plot_d <- seq.Date(min_plot_d,max_plot_d,by = "quarter")
  
  # standardized-date units per calendar day:
  unit_x <- (max_d_std-min_d_std)/as.numeric(max_d-min_d)
  plot_d_std <- as.numeric(plot_d - min_d)*unit_x+min_d_std
  
  # tick labels: show the year only on the first tick of each calendar year
  rle_res <- rle(lubridate::year(plot_d))
  format_seq <- rep("%b-%d",length(plot_d))
  format_seq[cumsum(c(1,rle_res$lengths[-length(rle_res$lengths)]))] <- "%Y:%b-%d"
  #####################################################################
  
  betaEti_samp <- array(t(get_res("^betaEti")),c(ncol_dm_Eti,Jcause,n_samp_kept)) #useful in effect estimation.
  ThetaBS_samp <- array(t(get_res(paste0("^ThetaBS_",slice,"\\["))),c(JBrS,K_curr,n_samp_kept))
  PsiBS_samp <- array(t(get_res(paste0("^PsiBS_",slice,"\\["))),c(JBrS,K_curr,n_samp_kept))
  Eta_samp <- array(t(get_res(paste0("^Eta_",slice,"\\["))),c(Nd,K_curr,n_samp_kept))
  Lambda_samp <- array(t(get_res(paste0("^Lambda_",slice,"\\["))),c(Nu,K_curr,n_samp_kept))
  # subclass weights, cases stacked on top of controls:
  subwt_samp <- abind::abind(Eta_samp,Lambda_samp,along=1)
  linpred      <- function(beta,design_matrix){design_matrix%*%beta}
  
  
  out_Eti_linpred     <- array(apply(betaEti_samp,3,linpred,design_matrix=bugs.dat$Z_Eti),
                               c(Nd,Jcause,n_samp_kept)) # can potentially just add pEti to the monitoring.
  
  # etiology probabilities for cases (softmax over causes), padded with zero
  # rows for controls so that rows line up with `subwt_samp`:
  pEti_samp <-abind::abind(aperm(apply(out_Eti_linpred,c(1,3),softmax),c(2,1,3)),
                           array(0,c(Nu,Jcause,n_samp_kept)),along=1)
  PR_case_ctrl <- compute_marg_PR_nested_reg_array(ThetaBS_array = ThetaBS_samp,PsiBS_array = PsiBS_samp,
                                                   pEti_mat_array = pEti_samp,subwt_mat_array = subwt_samp,
                                                   case = data_nplcm$Y,template = templateBS)
  
  
  #
  # 2. use this code if date is included in etiology and false positive regressions:
  #
  # false positive rates (selected controls, ordered by date):
  subset_FPR_ctrl <- data_nplcm$Y==0 & stratum_bool # <--- specifies who to look at.
  plotid_FPR_ctrl <- which(subset_FPR_ctrl)[order(data_nplcm$X$std_date[subset_FPR_ctrl])]
  curr_date_FPR <- data_nplcm$X$std_date[plotid_FPR_ctrl]
  # drop=FALSE keeps three dimensions when a single subject is selected
  FPR_prob_scale <- PR_case_ctrl[plotid_FPR_ctrl,,,drop=FALSE]
  
  FPR_mean <- apply(FPR_prob_scale,c(1,2),mean)
  FPR_q    <- apply(FPR_prob_scale,c(1,2),quantile,c(0.025,0.975))
  
  # ^ this could be changed theoretically to not use the observed data -> we could still make TPR/FPR plots with the posterior samples
  
  
  # selected cases, ordered by date; the same subjects are used for the case
  # positive rates and for the etiology curves:
  subset_FPR_case          <- data_nplcm$Y==1 & stratum_bool # <--- specifies who to look at.
  plotid_FPR_case          <- which(subset_FPR_case)[order(data_nplcm$X$std_date[subset_FPR_case])]
  curr_date_FPR_case       <- data_nplcm$X$std_date[plotid_FPR_case]
  plotid_Eti               <- plotid_FPR_case
  curr_date_Eti            <- curr_date_FPR_case
  
  ## compute the probabilities and posterior mean/quantiles
  ## (cause x subject x iteration)
  Eti_prob_scale <- aperm(pEti_samp,c(2,1,3))[,plotid_Eti,,drop=FALSE]
  Eti_mean <- apply(Eti_prob_scale,c(1,2),mean)
  Eti_q    <- apply(Eti_prob_scale,c(1,2),quantile,c(0.025,0.975))
  # average over the selected cases within each iteration (cause x iteration):
  Eti_overall <- apply(Eti_prob_scale,c(1,3),mean)
  Eti_overall_mean <- rowMeans(Eti_overall)
  Eti_overall_sd   <- apply(Eti_overall,1,sd)
  Eti_overall_q    <- apply(Eti_overall,1,quantile,c(0.025,0.975))
  
  ## for cases
  PR_case <- PR_case_ctrl[plotid_Eti,,,drop=FALSE]
  PR_case_mean <- apply(PR_case,c(1,2),mean)
  PR_case_q <- apply(PR_case,c(1,2),quantile,c(0.025,0.975))
  
  ##################
  # plot results:
  #################
  if (do_plot){  
    graphics::par(mfcol=c(2,Jcause),oma=c(3,0,3,0))
    for (j in seq_len(Jcause)){ # <--- the marginal dimension of measurements.
      # need to fix this for NoA! <------------------------ FIX!
      #
      # Figure 1 for case and control positive rates:
      #
      graphics::par(mar=c(2,5,0,1))
      #<------------------------ FIX!
      if (model_options$likelihood$cause_list[j] == "other"){
        # no measurement for "other": empty panel with the cause name
        plot(0,0.5,type="l",ylim=c(0,1),pch="n",
             xaxt="n",xlab="",ylab=c("","positive rate")[(j==1)+1],las=2,bty="n")
        
        mtext("other",side = 3,cex=1.5,line=1)
      } else if (is.na(match_cause(colnames(MBS_curr),model_options$likelihood$cause_list[j]))) {
        # cause not measured in this slice: empty panel with the cause name
        plot(0,0.5,type="l",ylim=c(0,1),pch="n",
             xaxt="n",xlab="",ylab=c("","positive rate")[(j==1)+1],las=2,bty="n")
        mtext(model_options$likelihood$cause_list[j],side=3,cex=1.5,line=1)
      } else{                                  #<------------------------ FIX!
        # NOTE(review): the measurement column is taken to be j, i.e. the j-th
        # cause is assumed to sit in the j-th column of the slice - confirm
        # whether the index returned by match_cause() should be used instead.
        plot(curr_date_FPR,FPR_mean[,j],type="l",ylim=c(0,1),
             xaxt="n",xlab="",ylab=c("","positive rate")[(j==1)+1],las=2,bty="n")
        polygon(c(curr_date_FPR, rev(curr_date_FPR)),
                c(FPR_q[1,,j], rev(FPR_q[2,,j])),
                col = grDevices::rgb(0, 1, 1,0.5),border = NA)
        
        # rug plot of controls' observed results in the chosen slice
        # (positives on top, negatives at the bottom):
        if(do_rug){
          rug(curr_date_FPR[MBS_curr[plotid_FPR_ctrl,j]==1],side=3,col="dodgerblue2",line=0)
          rug(curr_date_FPR[MBS_curr[plotid_FPR_ctrl,j]==0],side=1,col="dodgerblue2",line=1)
        }
        
        mtext(names(MBS_curr)[j],side = 3,cex=1.5,line=1)
        
        points(curr_date_FPR_case,PR_case_mean[,j],type="l",ylim=c(0,1))
        polygon(c(curr_date_FPR_case, rev(curr_date_FPR_case)),
                c(PR_case_q[1,,j], rev(PR_case_q[2,,j])),
                col =  grDevices::rgb(1, 0, 0,0.5),border = NA)
        
        # # make this optional for plotting
        # # rug plot:
        # if(do_rug){
        #   rug(curr_date_FPR_case[data_nplcm$Mobs$MBS[[1]][plotid_FPR_case,j]==1],side=3,line= 1)
        #   rug(curr_date_FPR_case[data_nplcm$Mobs$MBS[[1]][plotid_FPR_case,j]==0],side=1,line= 0)
        #   
        #   #labels for the rug plot
        if (j==1){
          #     mtext(text = "case   -->",side=2,at=line2user(1,3),cex=0.8,las=1)
          #     mtext(text = "case   -->",side=2,at=line2user(0,1),cex=0.8,las=1)
          #     mtext(text = "control-->",side=2,at=line2user(0,3), cex=0.8,las=1,col="dodgerblue2")
          #     mtext(text = "control-->",side=2,at=line2user(1,1), cex=0.8,las=1,col="dodgerblue2")
          #     
          mtext("1)",side=2,at=0.8,line=3, cex=2,las=1)
        }
        # }
        # 
      }
      #
      # Figure 2 for Etiology Regression:
      #
      graphics::par(mar=c(2,5,0,1))
      plot(curr_date_Eti,Eti_mean[j,],type="l",ylim=c(0,1),xlab="standardized date",
           ylab=c("","etiologic fraction")[(j==1)+1],bty="n",xaxt="n",yaxt="n",las=2)
      ## ONLY FOR SIMULATIONS <---------------------- FIX!
      
      # overall pie:
      abline(h=Eti_overall_mean[j],col="black",lwd=2)
      abline(h=Eti_overall_q[,j],col="black",lty=2,lwd=1.5)
      
      mtext(paste0(round(Eti_overall_mean[j],3)*100,"%"),side=3,line=-2,cex=1.2)
      mtext(paste0(round(Eti_overall_q[1,j],3)*100,"%"),side=3,line=-3,cex=0.8,adj=0.15)
      mtext(paste0(round(Eti_overall_q[2,j],3)*100,"%"),side=3,line=-3,cex=0.8,adj=0.85)
      
      if (j==2){
        mtext("<- Overall Pie ->",side=2,at=(line2user(-2,3)+line2user(-1,3))/2,las=1,cex=0.8,col="blue")
        mtext("<- 95% CrI ->",side=2,at=(line2user(-3,3)+line2user(-2,3))/2,las=1,cex=0.8,col="blue")
      }
      
      if (j==1){mtext("2)",side=2,at=0.85,line=3, cex=2,las=1)}
      
      # axis(1, X$std_date[c(plotid_FPR_case)], 
      #      format(c(X$date_month[c(plotid_FPR_case)]), "%Y %b"), 
      #      cex.axis = .7,las=2,srt=45)
      axis(1, plot_d_std,
           format(c(plot_d), 
                  format_seq),
           cex.axis = 0.8,las=2)
      
      axis(2,at = seq(0,1,by=0.2),labels=seq(0,1,by=0.2),las=2)
      
      # enrollment dates of the selected cases:
      rug(X$std_date[c(plotid_FPR_case)],side=1,line=-0.2,cex=1)
      
      if (j==1){
        mtext(text = "case   -->",side=2,at=line2user(-0.2,1),cex=0.8,las=1)
      }
      
      # pointwise 95% credible band of the etiology curve:
      polygon(c(curr_date_Eti, rev(curr_date_Eti)),
              c(Eti_q[1,j,], rev(Eti_q[2,j,])),
              col = grDevices::rgb(0.5,0.5,0.5,0.5),border = NA)
    }
  }
  if (return_metric){
    return(make_list(Eti_overall_mean,Eti_overall_q,Eti_overall_sd))
  }
}
# zhenkewu/baker documentation built on March 17, 2022, 9:54 p.m.