R/plot-etiology-regression.R

Defines functions plot_etiology_strat_nested plot_subwt_regression plot_etiology_strat plot_etiology_regression

Documented in plot_etiology_regression plot_etiology_strat plot_etiology_strat_nested plot_subwt_regression

#' visualize the etiology regression with a continuous covariate
#' 
#' This function visualizes the etiology regression against one continuous covariate, e.g., 
#' enrollment date. (NB: handling of NoA, multiple-pathogen causes, and other continuous
#' covariates is still to be addressed. Only one slice of bronze-standard data is plotted
#' per call; use `slice` to choose it, which defaults to the first slice.)
#' 
#' @param DIR_NPLCM File path to the folder containing posterior samples
#' @param stratum_bool a vector of TRUE/FALSE with TRUE indicating the rows of subjects to include
#' @param slice integer; specifies which slice of bronze-standard data to visualize; Default to 1.
#' @param plot_basis TRUE for plotting basis functions; Default to FALSE
#' @param truth a list of truths computed from true parameters in simulations; elements: 
#'  Eti, FPR, PR_case,TPR; All default to `NULL` in real data analyses.
#'  Currently only works for one slice of bronze-standard measurements (in a non-nested model).
#'  \itemize{
#'      \item Eti matrix of # of rows = # of subjects, # columns: `length(cause_list)` for Eti
#'      \item FPR matrix of # of rows = # of subjects, # columns: `ncol(data_nplcm$Mobs$MBS$MBS1)`
#'      \item PR_case matrix of # of rows = # of subjects, # columns: `ncol(data_nplcm$Mobs$MBS$MBS1)`
#'      \item TPR a vector of length identical to `PR_case`
#'  }
#' @param RES_NPLCM pre-read res_nplcm; default to NULL.
#' @param do_plot TRUE for plotting
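#' @param do_rug TRUE for adding rug plots of the observed positive (top axis) and 
#'  negative (bottom axis) measurements among cases and controls; Default to TRUE.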
#' @param return_metric TRUE for showing overall mean etiology, quantiles, s.d., and if `truth$Eti` is supplied, 
#'  coverage, bias, truth and integrated mean squared errors (IMSE).
#' @param plot_ma_dots plot moving averages among case and controls if TRUE; Default to FALSE.
#' 
#' 
#' @return A figure of etiology regression curves and some marginal positive rate assessment of
#' model fit; See example for the legends.
#' 
#' @import graphics
#' @importFrom lubridate day days_in_month month year
#' @examples 
#' \dontrun{
#' # legend.text = c("[UPPER FIGURES]",
#'                 "observed prevalence: cases",
#'                 "observed prevalence: controls",
#'                 "fitted prevalence: cases",
#'                 "fitted prevalence: controls",
#'                 "true positive rate: mean",
#'                 "true positive rate: 95%CI",
#'                 "[BOTTOM FIGURES]",
#'                 "etiology curve: mean",
#'                 "overall etiology: mean",
#'                 "overall etiology: 95%CI","","","")
#' legend.col=c("white","black","dodgerblue2","black","dodgerblue2","red","red",
#'              "white","springgreen4","orange","orange","white","white","white")
#' legend.lty=c(1,2,2,1,1,1,2,1,1,1,2)
#' legend.lwd=c(2,2,2,2,2,2,2,2,2,2,2,2,2,2)
#' legend("topleft",legend=legend.text,
#'        lty=legend.lty,lwd=legend.lwd,
#'        col=legend.col,ncol=2,
#'        y.intersp=1.5,cex=1.6,box.col=NA)
#'        }
#'        
#' @references See example figures 
#' \itemize{
#' \item A Figure using simulated data for six pathogens: <https://bit.ly/2FWoYeM>
#' \item The legends for the figure above: <https://bit.ly/2OU8F60>
#' }
#' @family visualization functions
#' @export
plot_etiology_regression <- function(DIR_NPLCM,stratum_bool,slice=1,plot_basis=FALSE,
                                     truth=NULL,RES_NPLCM=NULL,do_plot=TRUE,do_rug=TRUE, 
                                     return_metric=TRUE,
                                     plot_ma_dots=FALSE){
  # Overview:
  #   1. read the data, model options, JAGS data and posterior samples stored
  #      in DIR_NPLCM (or use the pre-read samples in RES_NPLCM);
  #   2. map calendar dates onto the standardized date scale for axis labels;
  #   3. compute, for the subjects selected by `stratum_bool`, posterior
  #      summaries of control positive rates, case positive rates and
  #      etiology fractions along the standardized date;
  #   4. optionally draw a 2 x Jcause panel figure (top: positive rates,
  #      bottom: etiology curves);
  #   5. optionally return overall etiology summaries (and, if `truth$Eti` is
  #      supplied, coverage, bias, truth and integrated squared error).
  if (!is_jags_folder(DIR_NPLCM)){
    stop("==[baker] Oops, not a folder baker recognizes. Try a folder generated by baker, e.g., a temporary folder?==")
  }
  
  # Graphical parameters are only touched (and a device only opened) when a
  # figure is requested. "oma" is saved as well because it is modified below.
  if (do_plot){
    old_par <- graphics::par(graphics::par("mfrow", "mar", "oma"))
    on.exit(graphics::par(old_par), add = TRUE)
  }
  
  #
  # Read data from DIR_NPLCM:
  #
  data_nplcm    <- dget(file.path(DIR_NPLCM,"data_nplcm.txt"))  
  model_options <- dget(file.path(DIR_NPLCM,"model_options.txt"))
  cause_list    <- model_options$likelihood$cause_list
  
  # NOTE(review): this condition assumes `k_subclass` has length one; with one
  # entry per slice the intended rule (any slice vs. the plotted slice) should
  # be confirmed against how the model was fitted.
  if(model_options$likelihood$k_subclass>1){
    is_nested <- TRUE
  } else{
    is_nested <- FALSE
  }
  
  if (do_plot){
    cat("==[baker] plotting etiology regression with >>",c("nested", "non-nested")[2-is_nested],"<< model for BrS Measure slice = ",slice,": ",names(data_nplcm$Mobs$MBS)[[slice]]," .==\n")
  }
  
  # jagsdata.txt is an R script defining the data objects passed to JAGS;
  # source it into a private environment and collect the objects in a list.
  jags_env <- new.env()
  source(file.path(DIR_NPLCM,"jagsdata.txt"),local=jags_env)
  bugs.dat <- as.list(jags_env)
  rm(jags_env)
  
  if (!is.null(RES_NPLCM)){
    res_nplcm <- RES_NPLCM
  } else {
    res_nplcm <- coda::read.coda(file.path(DIR_NPLCM,"CODAchain1.txt"),
                                 file.path(DIR_NPLCM,"CODAindex.txt"),
                                 quiet=TRUE)
  }
  # extract the posterior sample columns whose names match a regular expression:
  get_res <- function(x) res_nplcm[,grep(x,colnames(res_nplcm))]
  
  # dimensions used to structure the posterior samples:
  Z_FPR         <- bugs.dat[[paste0("Z_FPR_",slice)]]  # FPR design matrix
  MBS_slice     <- bugs.dat[[paste0("MBS_",slice)]]    # BrS data passed to JAGS
  ncol_dm_FPR   <- ncol(Z_FPR)
  JBrS          <- ncol(MBS_slice)
  n_samp_kept   <- nrow(res_nplcm)
  ncol_dm_Eti   <- ncol(bugs.dat$Z_Eti)
  Jcause        <- bugs.dat$Jcause
  Nd            <- bugs.dat$Nd
  Nu            <- bugs.dat$Nu
  K_curr        <- model_options$likelihood$k_subclass[slice]
  templateBS    <- bugs.dat[[paste0("templateBS_",slice)]]
  
  #####################################################################
  # x-axis: map calendar dates to the standardized date scale.
  X <- data_nplcm$X
  
  if (is.null(X$ENRLDATE) || is.null(X$std_date)){
    stop("'ENRLDATE' and/or 'std_date' is not a variable in your dataset! Make sure that the continuous covariate exists and retry this function.")
  }
  
  dd    <- as.Date(X$ENRLDATE)
  min_d <- min(dd)
  min_d_std  <- unique(X$std_date[which(X$ENRLDATE==min_d)])
  # first day of the month following the earliest enrollment date:
  min_plot_d <- min_d+days_in_month(month(min_d))-day(min_d)+1
  
  max_d <- max(dd)
  max_d_std  <- unique(X$std_date[which(X$ENRLDATE==max_d)])
  # first day of the month of the latest enrollment date:
  max_plot_d <- max_d-day(max_d)+1
  plot_d     <- seq.Date(min_plot_d,max_plot_d,by = "quarter")
  
  # standardized-date units per calendar day (linear interpolation between
  # the earliest and latest enrollment dates):
  unit_x     <- (max_d_std-min_d_std)/as.numeric(max_d-min_d)
  plot_d_std <- as.numeric(plot_d - min_d)*unit_x+min_d_std
  
  # axis labels: show the year only at the first tick of each calendar year.
  rle_res    <- rle(year(plot_d))
  format_seq <- rep("%b-%d",length(plot_d))
  format_seq[cumsum(c(1,rle_res$lengths[-length(rle_res$lengths)]))] <- "%Y:%b-%d"
  #####################################################################
  
  linpred <- function(beta,design_matrix){design_matrix%*%beta}
  
  if (!is_nested){
    betaFPR_samp <- array(t(get_res(paste0("^betaFPR_",slice,"\\["))),c(ncol_dm_FPR,JBrS,n_samp_kept))
    betaEti_samp <- array(t(get_res("^betaEti")),c(ncol_dm_Eti,Jcause,n_samp_kept))
    thetaBS_samp <- get_res(paste0("^thetaBS_",slice,"\\["))
    
    out_FPR_linpred <- array(apply(betaFPR_samp,3,linpred,design_matrix=Z_FPR),
                             c(Nd+Nu,JBrS,n_samp_kept))
    out_Eti_linpred <- array(apply(betaEti_samp,3,linpred,design_matrix=bugs.dat$Z_Eti),
                             c(Nd,Jcause,n_samp_kept))
  } else{
    betaEti_samp <- array(t(get_res("^betaEti")),c(ncol_dm_Eti,Jcause,n_samp_kept))
    ThetaBS_samp <- array(t(get_res(paste0("^ThetaBS_",slice,"\\["))),c(JBrS,K_curr,n_samp_kept))
    PsiBS_samp   <- array(t(get_res(paste0("^PsiBS_",slice,"\\["))),c(JBrS,K_curr,n_samp_kept))
    Eta_samp     <- array(t(get_res(paste0("^Eta_",slice,"\\["))),c(Nd,K_curr,n_samp_kept))
    Lambda_samp  <- array(t(get_res(paste0("^Lambda_",slice,"\\["))),c(Nu,K_curr,n_samp_kept))
    # subclass weights: case rows (Eta) stacked on top of control rows (Lambda).
    subwt_samp   <- abind::abind(Eta_samp,Lambda_samp,along=1)
    
    out_Eti_linpred <- array(apply(betaEti_samp,3,linpred,design_matrix=bugs.dat$Z_Eti),
                             c(Nd,Jcause,n_samp_kept)) # can potentially just add pEti to the monitoring.
    # etiology probabilities (subject x cause x iteration); the Nu rows
    # appended after the Nd case rows are filled with zeros.
    pEti_samp    <- abind::abind(aperm(apply(out_Eti_linpred,c(1,3),softmax),c(2,1,3)),
                                 array(0,c(Nu,Jcause,n_samp_kept)),along=1)
    PR_case_ctrl <- compute_marg_PR_nested_reg_array(ThetaBS_array = ThetaBS_samp,PsiBS_array = PsiBS_samp,
                                                     pEti_mat_array = pEti_samp,subwt_mat_array = subwt_samp,
                                                     case = data_nplcm$Y,template = templateBS)
  }
  
  #
  # Controls: false positive rates, ordered by standardized date.
  #
  subset_FPR_ctrl <- data_nplcm$Y==0 & stratum_bool # <--- specifies who to look at.
  plotid_FPR_ctrl <- which(subset_FPR_ctrl)[order(data_nplcm$X$std_date[subset_FPR_ctrl])]
  curr_date_FPR   <- data_nplcm$X$std_date[plotid_FPR_ctrl]
  if (!is_nested){
    FPR_prob_scale <- expit(out_FPR_linpred[plotid_FPR_ctrl,,])
  } else{
    FPR_prob_scale <- PR_case_ctrl[plotid_FPR_ctrl,,]
  }
  FPR_mean <- apply(FPR_prob_scale,c(1,2),mean)
  FPR_q    <- apply(FPR_prob_scale,c(1,2),stats::quantile,c(0.025,0.975))
  # ^ this could be changed theoretically to not use the observed data -> we
  #   could still make TPR/FPR plots with the posterior samples
  
  # Marginal positive rate of each measurement for one case at one iteration:
  # rows of `template` are causes and columns are measurements; entries equal
  # to 1 contribute the true positive rate `theta`, entries equal to 0
  # contribute the false positive rate `psi`, weighted by the etiology
  # probabilities `pEti_ord`.
  fitted_margin_case <- function(pEti_ord,theta,psi,template){
    mixture <- pEti_ord
    tpr     <- t(t(template)*theta)
    fpr     <- t(t(1-template)*psi)
    colSums(tpr*mixture + fpr*mixture)
  }
  
  #
  # Cases: positive rates, ordered by standardized date.
  #
  subset_FPR_case    <- data_nplcm$Y==1 & stratum_bool # <--- specifies who to look at.
  plotid_FPR_case    <- which(subset_FPR_case)[order(data_nplcm$X$std_date[subset_FPR_case])]
  curr_date_FPR_case <- data_nplcm$X$std_date[plotid_FPR_case]
  if (!is_nested){
    FPR_prob_scale_case <- expit(out_FPR_linpred[plotid_FPR_case,,])
  } else{
    FPR_prob_scale_case <- PR_case_ctrl[plotid_FPR_case,,]
  }
  
  #
  # Etiology (cases only), ordered by standardized date.
  #
  subset_Eti    <- data_nplcm$Y==1 & stratum_bool # <--- specifies who to look at.
  plotid_Eti    <- which(subset_Eti)[order(data_nplcm$X$std_date[subset_Eti])]
  curr_date_Eti <- data_nplcm$X$std_date[plotid_Eti]
  
  # Eti_prob_scale: cause x subject x iteration.
  if (!is_nested){
    Eti_prob_scale <- apply(out_Eti_linpred[plotid_Eti,,],c(1,3),softmax)
  } else{
    Eti_prob_scale <- aperm(pEti_samp,c(2,1,3))[,plotid_Eti,]
  }
  Eti_mean <- apply(Eti_prob_scale,c(1,2),mean)
  Eti_q    <- apply(Eti_prob_scale,c(1,2),stats::quantile,c(0.025,0.975))
  # overall etiology: average over the selected cases within each iteration.
  Eti_overall      <- apply(Eti_prob_scale,c(1,3),mean)
  Eti_overall_mean <- rowMeans(Eti_overall)
  Eti_overall_sd   <- apply(Eti_overall,1,stats::sd)
  Eti_overall_q    <- apply(Eti_overall,1,stats::quantile,c(0.025,0.975))
  
  if (!is_nested){
    # Use the template of the requested slice. (`bugs.dat$templateBS` only
    # resolved through partial matching of `$` and is ambiguous when the JAGS
    # data contain more than one `templateBS_*` element.)
    template_cause <- templateBS[seq_len(Jcause),,drop=FALSE]
    PR_case <- array(NA,c(length(plotid_Eti),JBrS,n_samp_kept))
    for (i in seq_along(plotid_Eti)){
      for (t in seq_len(n_samp_kept)){
        PR_case[i,,t] <- fitted_margin_case(Eti_prob_scale[,i,t],
                                            thetaBS_samp[t,],
                                            FPR_prob_scale_case[i,,t],
                                            template_cause)
      }
    }
  } else{
    PR_case <- PR_case_ctrl[plotid_Eti,,]
  }
  
  PR_case_mean <- apply(PR_case,c(1,2),mean)
  PR_case_q    <- apply(PR_case,c(1,2),stats::quantile,c(0.025,0.975))
  
  ##################
  # plot results:
  #################
  if (do_plot){  
    # observed measurements of the plotted slice (for rugs and panel titles):
    MBS_obs <- data_nplcm$Mobs$MBS[[slice]]
    
    # local running mean of y over a window of half-width hw on the x scale:
    ma_cont <- function(y,x,hw=0.35){
      res <- rep(NA,length(y))
      for (i in seq_along(y)){
        res[i] <- mean(y[which(x>=x[i]-hw & x<=x[i]+hw)])
      }
      res
    }
    
    par(mfcol=c(2,Jcause),oma=c(3,0,3,0))
    for (j in seq_len(Jcause)){ # <--- the marginal dimension of measurements.
      # need to fix this for NoA! <------------------------ FIX!
      #
      # Figure 1 for case and control positive rates:
      #
      par(mar=c(2,5,0,1))
      #<------------------------ FIX!
      if (cause_list[j] == "other" ||
          is.na(match_cause(colnames(MBS_obs),cause_list[j]))){
        # no measurement in this slice corresponds to the cause: draw an empty
        # placeholder panel titled with the cause name.
        plot(0,0.5,type="l",ylim=c(0,1),pch="n",
             xaxt="n",xlab="",ylab=c("","positive rate")[(j==1)+1],las=2,bty="n")
        mtext(cause_list[j],side=3,cex=1.5,line=1)
      } else{                                  #<------------------------ FIX!
        # controls: posterior mean and 95% band of the positive rate.
        plot(curr_date_FPR,FPR_mean[,j],type="l",ylim=c(0,1),
             xaxt="n",xlab="",ylab=c("","positive rate")[(j==1)+1],las=2,bty="n")
        polygon(c(curr_date_FPR, rev(curr_date_FPR)),
                c(FPR_q[1,,j], rev(FPR_q[2,,j])),
                col = grDevices::rgb(0, 1, 1,0.5),border = NA)
        
        # rug plot for controls: positives on top, negatives at the bottom.
        if (do_rug){
          rug(curr_date_FPR[MBS_obs[plotid_FPR_ctrl,j]==1],side=3,col="dodgerblue2",line=0)
          rug(curr_date_FPR[MBS_obs[plotid_FPR_ctrl,j]==0],side=1,col="dodgerblue2",line=1)
        }
        
        if (!is.null(truth$FPR)){lines(curr_date_FPR,truth$FPR[plotid_FPR_ctrl,j],col="blue",lwd=3)}
        if (!is.null(truth$TPR)){abline(h=truth$TPR[j],lwd=3,col="black")}
        if (plot_basis){matplot(curr_date_FPR,Z_FPR[plotid_FPR_ctrl,],col="blue",type="l",add=TRUE)}
        
        mtext(names(MBS_obs)[j],side = 3,cex=1.5,line=1)
        
        # cases: posterior mean and 95% band of the positive rate.
        points(curr_date_FPR_case,PR_case_mean[,j],type="l",ylim=c(0,1))
        polygon(c(curr_date_FPR_case, rev(curr_date_FPR_case)),
                c(PR_case_q[1,,j], rev(PR_case_q[2,,j])),
                col =  grDevices::rgb(1, 0, 0,0.5),border = NA)
        if (!is.null(truth$PR_case)){lines(curr_date_FPR_case,truth$PR_case[plotid_FPR_case,j],col="black",lwd=3)}
        
        # rug plot for cases: positives on top, negatives at the bottom.
        if (do_rug){
          rug(curr_date_FPR_case[MBS_obs[plotid_FPR_case,j]==1],side=3,line= 1)
          rug(curr_date_FPR_case[MBS_obs[plotid_FPR_case,j]==0],side=1,line= 0)
          
          #labels for the rug plot
          if (j==1){
            mtext(text = "case   -->",side=2,at=line2user(1,3),cex=0.8,las=1)
            mtext(text = "case   -->",side=2,at=line2user(0,1),cex=0.8,las=1)
            mtext(text = "control-->",side=2,at=line2user(0,3), cex=0.8,las=1,col="dodgerblue2")
            mtext(text = "control-->",side=2,at=line2user(1,1), cex=0.8,las=1,col="dodgerblue2")
            
            mtext("1)",side=2,at=0.8,line=3, cex=2,las=1)
          }
        }
        
        # true positive rate: posterior mean and 95% interval.
        if (!is_nested){
          abline(h=colMeans(thetaBS_samp)[j],col="red")
          abline(h=stats::quantile(thetaBS_samp[,j],0.025),col="red",lty=2)
          abline(h=stats::quantile(thetaBS_samp[,j],0.975),col="red",lty=2)
        }
        
        # raw moving-average dots of the observed positive rates. Missing
        # measurements are dropped exactly once so that dates and responses
        # stay aligned.
        if (plot_ma_dots){
          response_ctrl <- MBS_slice[plotid_FPR_ctrl,j]
          keep_ctrl     <- !is.na(response_ctrl)
          dat_ctrl      <- data.frame(std_date=curr_date_FPR[keep_ctrl])
          dat_ctrl$runmean <- ma_cont(response_ctrl[keep_ctrl],dat_ctrl$std_date)
          if (nrow(dat_ctrl) > 0){
            points(runmean ~ std_date,data=dat_ctrl,lty=2,pch=1,cex=0.5,type="o",col="dodgerblue2")
          }
          
          response_case <- MBS_slice[plotid_FPR_case,j]
          keep_case     <- !is.na(response_case)
          dat_case      <- data.frame(std_date=curr_date_FPR_case[keep_case])
          dat_case$runmean <- ma_cont(response_case[keep_case],dat_case$std_date)
          if (nrow(dat_case) > 0){
            points(runmean ~ std_date,data=dat_case,lty=2,pch=1,cex=0.5,type="o")
          }
        }
      }
      #
      # Figure 2 for Etiology Regression:
      #
      par(mar=c(2,5,0,1))
      plot(curr_date_Eti,Eti_mean[j,],type="l",ylim=c(0,1),xlab="standardized date",
           ylab=c("","etiologic fraction")[(j==1)+1],bty="n",xaxt="n",yaxt="n",las=2)
      ## ONLY FOR SIMULATIONS <---------------------- FIX!
      if (!is.null(truth$Eti)){
        points(curr_date_Eti,truth$Eti[plotid_Eti,j],type="l",lwd=3,col="black")
        abline(h=colMeans(truth$Eti[data_nplcm$Y==1,])[j],col="blue",lwd=3)
      }
      if (plot_basis){matplot(curr_date_Eti,bugs.dat$Z_Eti[plotid_Eti,],col="blue",type="l",add=TRUE)}
      
      # overall pie:
      abline(h=Eti_overall_mean[j],col="black",lwd=2)
      abline(h=Eti_overall_q[,j],col="black",lty=2,lwd=1.5)
      
      mtext(paste0(round(Eti_overall_mean[j],3)*100,"%"),side=3,line=-2,cex=1.2)
      mtext(paste0(round(Eti_overall_q[1,j],3)*100,"%"),side=3,line=-3,cex=1,adj=0.15)
      mtext(paste0(round(Eti_overall_q[2,j],3)*100,"%"),side=3,line=-3,cex=1,adj=0.85)
      
      if (j==2){
        mtext("<- Overall Pie ->",side=2,at=(line2user(-2,3)+line2user(-1,3))/2,las=1,cex=0.8,col="blue")
        mtext("<- 95% CrI ->",side=2,at=(line2user(-3,3)+line2user(-2,3))/2,las=1,cex=0.8,col="blue")
      }
      
      if (j==1){mtext("2)",side=2,at=0.85,line=3, cex=2,las=1)}
      
      # calendar-date labels placed at their standardized-date positions:
      axis(1, plot_d_std,
           format(c(plot_d), 
                  format_seq),
           cex.axis = 0.8,las=2)
      
      axis(2,at = seq(0,1,by=0.2),labels=seq(0,1,by=0.2),las=2)
      
      rug(curr_date_FPR_case,side=1,line=-0.2,cex=1)
      
      if (j==1){
        mtext(text = "case   -->",side=2,at=line2user(-0.2,1),cex=0.8,las=1)
      }
      
      # 95% pointwise band of the etiology curve:
      polygon(c(curr_date_Eti, rev(curr_date_Eti)),
              c(Eti_q[1,j,], rev(Eti_q[2,j,])),
              col = grDevices::rgb(0.5,0.5,0.5,0.5),border = NA)
    }
  }
  if (return_metric){
    if (!is.null(truth$Eti)){
      Eti_overall_truth  <- colMeans(truth$Eti[plotid_Eti,])
      # integrated squared error of the posterior mean curve, per cause:
      Eti_ISE <- apply(Eti_mean-t(truth$Eti[plotid_Eti,]),1,function(v) sum(v^2)/length(v))
      # does the 95% interval of the overall etiology cover the truth?
      Eti_overall_cover  <- vapply(seq_along(Eti_overall_mean),
                                   function(s) (Eti_overall_truth[s]<= Eti_overall_q[2,s]) && 
                                     (Eti_overall_truth[s]>= Eti_overall_q[1,s]),
                                   logical(1))
      Eti_overall_bias  <- Eti_overall_mean -  Eti_overall_truth
      return(make_list(Eti_overall_mean,Eti_overall_q,Eti_overall_sd,
                       Eti_overall_cover, Eti_overall_bias,Eti_overall_truth,Eti_ISE))
    } else{
      return(make_list(Eti_overall_mean,Eti_overall_q,Eti_overall_sd))
    }
  }
  
} 


#' visualize the etiology estimates for each discrete level
#' 
#' This function visualizes the etiology estimates against one discrete covariate, e.g., 
#' age groups. 
#' 
#' @param DIR_NPLCM File path to the folder containing posterior samples
#' @param strata_weights a vector of weights that sum to one; for each pathogen
#' the weights specify how the j-th etiology fraction should be combined across all
#' levels of the discrete predictors in the data
#' @param truth a list of true values, e.g., 
#' `truth=list(allEti = a list of etiology fractions)`
#' 
#' @import graphics
#' @family visualization functions
#' @export
plot_etiology_strat <- function(DIR_NPLCM,strata_weights=NULL,truth=NULL){
  # Draws, for every cause (one column of panels per cause), a histogram of the
  # posterior etiology fraction in each stratum, followed by a histogram of the
  # overall etiology fraction obtained by combining strata with `user_weight`.
  # If `truth$allEti` is supplied, the true values and 95% interval limits are
  # marked (interval limits in gray when the truth is covered, red otherwise).
  if (!is_jags_folder(DIR_NPLCM)){
    stop("==[baker] oops, not a folder baker recognizes. 
         Try a folder generated by baker, e.g., a temporary folder?==")
  }
  # restore the graphical parameters modified below when the function exits:
  old_par <- graphics::par(graphics::par("mfrow", "mar"))
  on.exit(graphics::par(old_par), add = TRUE)
  
  #
  # Read data from DIR_NPLCM:
  #
  data_nplcm    <- dget(file.path(DIR_NPLCM,"data_nplcm.txt"))  
  model_options <- dget(file.path(DIR_NPLCM,"model_options.txt"))  
  cause_list    <- model_options$likelihood$cause_list
  
  # jagsdata.txt is an R script defining the data objects passed to JAGS:
  jags_env <- new.env()
  source(file.path(DIR_NPLCM,"jagsdata.txt"),local=jags_env)
  bugs.dat <- as.list(jags_env)
  rm(jags_env)
  res_nplcm <- coda::read.coda(file.path(DIR_NPLCM,"CODAchain1.txt"),
                               file.path(DIR_NPLCM,"CODAindex.txt"),
                               quiet=TRUE)
  # extract the posterior sample columns whose names match a regular expression:
  get_res <- function(x) res_nplcm[,grep(x,colnames(res_nplcm))]
  
  # structure the posterior samples:
  n_samp_kept        <- nrow(res_nplcm)             # number of posterior samples kept
  Jcause             <- bugs.dat$Jcause             # number of causes
  n_unique_Eti_level <- bugs.dat$n_unique_Eti_level # number of strata
  
  # etiology samples: iteration x stratum x cause.
  Eti_prob_scale <- array(get_res("pEti"),c(n_samp_kept,n_unique_Eti_level,Jcause))
  
  # weights used to combine strata into the overall etiology; equal weights
  # unless the user supplies one weight per stratum.
  user_weight <- rep(1/n_unique_Eti_level,n_unique_Eti_level)
  if (!is.null(strata_weights)){
    if (length(strata_weights)==n_unique_Eti_level){
      user_weight <- strata_weights
    } else{
      warning("==[baker] 'strata_weights' has length ",length(strata_weights),
              " but there are ",n_unique_Eti_level,
              " strata; equal weights are used instead.==",call.=FALSE)
    }
  }
  # weighted overall etiology: cause x iteration.
  Eti_overall_usr_weight <- apply(Eti_prob_scale,1,function(S) t(S)%*%matrix(user_weight,ncol=1))
  
  # true overall etiology (one value per cause) under the same weights:
  truth_overall <- NULL
  if (!is.null(truth$allEti)){
    truth_overall <- c(matrix(user_weight,nrow=1)%*%do.call(rbind,truth$allEti))
  }
  
  # NOTE(review): stratum labels are taken from the levels of
  # `data_nplcm$X$SITE`; confirm this matches the discrete covariate that
  # defines the strata in the fitted model.
  level_names <- levels(as.factor(data_nplcm$X$SITE))
  
  # plot posterior distribution for etiology probability
  par(mfcol=c(1+n_unique_Eti_level,Jcause),mar=c(3,10,1,0))
  for (j in seq_len(Jcause)){
    # stratum-specific histograms:
    for (site in seq_len(n_unique_Eti_level)){
      hist(Eti_prob_scale[,site,j],xlim=c(0,1),breaks="Scott",freq=FALSE,main="",xlab="")
      if (!is.null(truth$allEti)){
        truth_sj <- truth$allEti[[site]][[j]]
        abline(v = truth_sj, col="blue", lwd=3, lty=2) # mark the truth.
        q_interval  <- stats::quantile(Eti_prob_scale[,site,j],c(0.025,0.975))
        is_included <- truth_sj < q_interval[2] && truth_sj > q_interval[1]
        abline(v = q_interval,col=c("gray","red")[2-is_included],lwd=2,lty=1)
      }
      # panel label: stratum level and cause name.
      mtext(text = paste0('level',level_names[site],": ",
                          cause_list[j]),3,-1,cex=2,adj = 0.9)
      if (j==1){
        mtext(paste0(round(user_weight[site],4)),2,5,cex=2,col="blue",las=1)
        if (site==ceiling(n_unique_Eti_level/2)) {mtext("User-specified weight towards overall pie:", 2,12, cex=3)}
      }
    }
    # overall (weighted) histogram for cause j:
    hist(Eti_overall_usr_weight[j,],xlim=c(0,1),breaks="Scott",freq=FALSE,main="",col="blue",
         xlab="Etiology")
    if (!is.null(truth_overall)){
      abline(v=truth_overall[j],col="orange",lty=1,lwd=2) # mark the truth.
      q_interval_Eti_overall  <- stats::quantile(Eti_overall_usr_weight[j,],c(0.025,0.975))
      is_included_Eti_overall <- truth_overall[j] < q_interval_Eti_overall[2] && 
        truth_overall[j] > q_interval_Eti_overall[1]
      abline(v = q_interval_Eti_overall,col=c("gray","red")[2-is_included_Eti_overall],lwd=2,lty=1)
    }
    mtext(text = paste0("Overall: ",cause_list[j]),3,adj=0.9,cex=2,col="blue")
  }
}

#' visualize the subclass weight regression with a continuous covariate
#' 
#' 
#' @param DIR_NPLCM File path to the folder containing posterior samples
#' @param stratum_bool a vector of TRUE/FALSE with TRUE indicating the rows of subjects to include
#' @param case 1 for plotting cases, 0 for plotting controls; default to 0.
#' @param slice integer; specifies which slice of bronze-standard data to visualize; Default to 1.
#' @param truth a list of truths computed from true parameters in simulations; elements: 
#'  Eti, FPR, PR_case,TPR; All default to `NULL` in real data analyses.
#'  Currently only works for one slice of bronze-standard measurements (in a non-nested model).
#'  \itemize{
#'      \item truth_subwt matrix of # of rows = # of subjects, # columns: number of true subclasses
#' }
#' @param RES_NPLCM pre-read res_nplcm; default to NULL.
#' @return A figure of subclass regression curves 
#' @family visualization functions
#' @export
plot_subwt_regression <- function(DIR_NPLCM,stratum_bool,case=0,slice=1,truth=NULL,RES_NPLCM=NULL){
  old_par <- graphics::par(graphics::par("mfrow", "mar"))
  on.exit(graphics::par(old_par))
  if (!is_jags_folder(DIR_NPLCM)){
    stop("==[baker] Oops, not a folder baker recognizes. Try a folder generated by baker, e.g., a temporary folder?==")
  }
  # JAGS:
  #
  # Read data from DIR_NPLCM:
  #
  data_nplcm <- dget(file.path(DIR_NPLCM,"data_nplcm.txt"))  
  model_options <- dget(file.path(DIR_NPLCM,"model_options.txt"))
  mcmc_options <- dget(file.path(DIR_NPLCM,"mcmc_options.txt"))
  parsed_model <- assign_model(model_options,data_nplcm)
  is_nested    <- parsed_model$nested
  cat("==[baker] plotting >>",c("case","control")[2-case],"<< subclass weight regression with >> nested << model for BrS Measure slice = ",slice,": ",names(data_nplcm$Mobs$MBS)[[slice]]," .==\n")
  
  new_env <- new.env()
  source(file.path(DIR_NPLCM,"jagsdata.txt"),local=new_env)
  bugs.dat <- as.list(new_env)
  rm(new_env)
  if (!is.null(RES_NPLCM)){res_nplcm <- RES_NPLCM
  } else {res_nplcm <- coda::read.coda(file.path(DIR_NPLCM,"CODAchain1.txt"),
                                       file.path(DIR_NPLCM,"CODAindex.txt"),
                                       quiet=TRUE)}
  print_res <- function(x) plot(res_nplcm[,grep(x,colnames(res_nplcm))])
  get_res   <- function(x) res_nplcm[,grep(x,colnames(res_nplcm))]
  
  # structure the posterior samples:
  n_samp_kept   <- nrow(res_nplcm)
  Nd            <- bugs.dat$Nd
  Nu            <- bugs.dat$Nu
  K_curr        <- model_options$likelihood$k_subclass[slice]
  
  Eta_samp <- array(t(get_res(paste0("^Eta_",slice,"\\["))),c(Nd,K_curr,n_samp_kept))
  Lambda_samp <- array(t(get_res(paste0("^Lambda_",slice,"\\["))),c(Nu,K_curr,n_samp_kept))
  subwt_samp <- abind::abind(Eta_samp,Lambda_samp,along=1)
  
  if(is.null(data_nplcm$X$std_date)){
    stop("'ENRLDATE' and/or 'std_date' is not a variable in your dataset! Make sure that the continuous covariate exists and retry this function.")
  }
  
  #
  # 2. use this code if date is included in etiology and false positive regressions:
  #
  # false positive rates:
  subset_FPR_ctrl     <- data_nplcm$Y==case & stratum_bool # <--- specifies who to look at.
  plotid_FPR_ctrl     <- which(subset_FPR_ctrl)[order(data_nplcm$X$std_date[subset_FPR_ctrl])]
  #
  # LCM plotting subclass weight curves:
  #
  k_seq <- 1:K_curr #c(1,2,3)  # <----------------- adjust order of k.
  if(!is.null(truth$ord_subclass)){k_seq <- truth$ord_subclass} #c(1,2,3)                                      # <----------------- adjust order of k.
  #k_seq <- c(1,5,2,3,4)#1:K # <----------------- adjust order of k.
  if (!is.null(truth$truth_subwt)){
    K_truth <- ncol(truth$truth_subwt);
    if (K_curr > K_truth){truth_subwt <- cbind(truth_subwt,matrix(0,nrow=Nd+Nu, ncol=K_curr - K_truth))}
  }
  
  ## match the truth subclass weights to the real data 
  if(!is.null(truth$truth_subwt)){
    true_classes = seq_along(k_seq)
    class_to_true_class = rep(0, length(true_classes))

    for (class in seq_along(class_to_true_class)) {
      dist_class_true_class = matrix(data=NA, ncol=K_curr, nrow=K_curr)
      for (true_class in true_classes) {
        # find the data that has the highest correlation with the subweights
        dist_class_true_class[class, true_class] <- mean(
          stats::cor(subwt_samp[plotid_FPR_ctrl, k_seq[class], ], truth_subwt[plotid_FPR_ctrl, k_seq[true_class]])
        )
      }
      class_to_true_class[class] <- which.max(dist_class_true_class[class,])
      
    }
    
  }
  
  
  par(mfrow=c(2,K_curr))
  for (k in seq_along(k_seq)){
    # posterior of subclass weight:
    matplot(data_nplcm$X$std_date[plotid_FPR_ctrl],subwt_samp[plotid_FPR_ctrl,k_seq[k],],
            col=2,type="l",ylim=c(0,1),main=k,xlab="scaled date",ylab="subclass weight")
    # # posterior of subclass latent Gaussian mean:
    # #true subclass weights:
    true_k = class_to_true_class[k_seq[k]]
    if (!is.null(truth$truth_subwt)){matplot(data_nplcm$X$std_date[plotid_FPR_ctrl], truth_subwt[plotid_FPR_ctrl, true_k],
                                             type="l",add=TRUE,lwd=4,col=1,lty=c(1,1,1),xlab="scaled date",ylab="subclass weight")} 
  }
  
  for (k in seq_along(k_seq)){
    # posterior of subclass weight:
    matplot(data_nplcm$X$std_date[plotid_FPR_ctrl],t(apply(subwt_samp[plotid_FPR_ctrl,k_seq[k],],1,quantile, c(0.025,0.975))),
            col="blue",#c(col1,col2,col3)[k],
            type="l",ylim=c(0,1),main=k,lty=2,xlab="scaled date",ylab="subclass weight")
    points(data_nplcm$X$std_date[plotid_FPR_ctrl],apply(subwt_samp[plotid_FPR_ctrl,k_seq[k],],1,mean),col="black",lty=2,
           type="l",xlab="scaled date",ylab="subclass weight")
    # # posterior of subclass latent Gaussian mean:
    # matplot(x,t(res_mu_alpha),col=col3,type="l",main="posterior of latent Gaussian mean")
    # # true subclass weights:
    true_k = class_to_true_class[k_seq[k]]
    if (!is.null(truth$truth_subwt)){matplot(data_nplcm$X$std_date[plotid_FPR_ctrl], truth_subwt[plotid_FPR_ctrl,true_k],type="l",add=TRUE,lwd=4,
                                             col=c("black","black","black"),lty=c(1,1,1),xlab="scaled date",ylab="subclass weight")}
  }
}


#' visualize the etiology estimates for each discrete levels
#' 
#' This function visualizes the etiology estimates against one discrete covariate, e.g., 
#' age groups. For every cause it draws one posterior histogram per stratum (a unique
#' row of the etiology design matrix built from the cases) and, underneath, one
#' histogram of the overall etiology fraction obtained by combining the strata with
#' `strata_weights`.
#' 
#' @param DIR_NPLCM File path to the folder containing posterior samples
#' @param strata_weights a vector of weights that sum to one; for each pathogen
#' the weights specify how the j-th etiology fraction should be combined across all
#' levels of the discrete predictors in the data; can also specify as `"empirical"`
#' to use empirical weights (fractions of cases in each stratum). With `NULL`
#' (default) equal weights are used. A supplied value that is not a numeric vector
#' with one entry per stratum is ignored with a warning (equal weights are used).
#' @param truth a list of true values, e.g., 
#' `truth=list(allEti = a list of etiology fractions)`; when `truth$allEti` is
#' supplied the true values and the 95% credible intervals are marked on the
#' histograms (interval drawn in gray if it contains the truth, red otherwise).
#' @param RES_NPLCM pre-read res_nplcm; default to NULL.
#' @return Called for its side effect of drawing the histograms.
#' @import graphics
#' @family visualization functions
#' @export

plot_etiology_strat_nested <- function(DIR_NPLCM,strata_weights = NULL,
                                       truth=NULL,RES_NPLCM=NULL){
  # remember the current layout/margins and restore them when the function exits:
  old_par <- graphics::par("mfrow", "mar")
  on.exit(graphics::par(old_par), add = TRUE)
  if (!is_jags_folder(DIR_NPLCM)){
    stop("==[baker] Oops, not a folder baker recognizes. Try a folder generated by baker, e.g., a temporary folder?==")
  }
  #
  # Read data and model specification from DIR_NPLCM:
  #
  data_nplcm    <- dget(file.path(DIR_NPLCM,"data_nplcm.txt"))
  model_options <- dget(file.path(DIR_NPLCM,"model_options.txt"))
  parsed_model  <- assign_model(model_options,data_nplcm)
  is_nested     <- parsed_model$nested
  cat("==[baker] plotting etiology regression with >>",c("nested", "non-nested")[2-is_nested],"<< model==\n")
  
  # jagsdata.txt is sourced into a private environment so that its objects
  # do not leak into this function's frame:
  new_env <- new.env()
  source(file.path(DIR_NPLCM,"jagsdata.txt"),local=new_env)
  bugs.dat <- as.list(new_env)
  rm(new_env)
  
  if (!is.null(RES_NPLCM)){
    res_nplcm <- RES_NPLCM
  } else {
    res_nplcm <- coda::read.coda(file.path(DIR_NPLCM,"CODAchain1.txt"),
                                 file.path(DIR_NPLCM,"CODAindex.txt"),
                                 quiet=TRUE)
  }
  # columns of the posterior sample matrix whose names match a regular expression:
  get_res <- function(x) res_nplcm[,grep(x,colnames(res_nplcm))]
  
  # structure the posterior samples:
  n_samp_kept <- nrow(res_nplcm)
  Jcause      <- bugs.dat$Jcause
  
  likelihood <- model_options$likelihood
  Y          <- data_nplcm$Y
  X          <- data_nplcm$X
  
  # design matrix for etiology regression, built from the cases only (because
  # this function is for all discrete predictors, the usual model.matrix is
  # sufficient):
  Eti_formula <- likelihood$Eti_formula
  case_data   <- data.frame(X,Y)[Y==1,,drop=FALSE]
  # NOTE(review): the result is not used below; the call is kept in case
  # is_discrete() also validates its inputs -- confirm before removing.
  is_discrete_Eti <- is_discrete(case_data, Eti_formula)
  Z_Eti           <- stats::model.matrix(Eti_formula,case_data)
  
  ncol_dm_Eti            <- ncol(Z_Eti)
  Eti_colname_design_mat <- attributes(Z_Eti)$dimnames[[2]]
  # keep only "dim" so the strata matrix below starts without dimnames:
  attributes(Z_Eti)[names(attributes(Z_Eti))!="dim"] <- NULL 
  
  # one row per stratum, i.e., per distinct covariate pattern among the cases:
  unique_Eti_level   <- unique(Z_Eti)
  n_unique_Eti_level <- nrow(unique_Eti_level)
  rownames(unique_Eti_level) <- seq_len(n_unique_Eti_level)
  colnames(unique_Eti_level) <- Eti_colname_design_mat
  
  # regression coefficients: (design columns) x (causes) x (iterations)
  betaEti_samp <- array(t(get_res("^betaEti")),c(ncol_dm_Eti,Jcause,n_samp_kept))
  linpred      <- function(beta,design_matrix){design_matrix%*%beta}
  
  # linear predictor per stratum, then softmax across causes:
  out_Eti_linpred <- array(apply(betaEti_samp,3,linpred,design_matrix=unique_Eti_level),
                           c(n_unique_Eti_level,Jcause,n_samp_kept)) # can potentially just add pEti to the monitoring.
  pEti_samp       <- aperm(apply(out_Eti_linpred,c(1,3),softmax),c(2,1,3)) # strata x causes x iterations
  
  # iterations x strata x causes, the layout used for the plots below:
  Eti_prob_scale <- aperm(pEti_samp,c(3,1,2))
  
  #
  # weights used to marginalize the posterior etiology distributions over strata:
  #
  user_weight   <- rep(1/n_unique_Eti_level,n_unique_Eti_level) # default: equal weights
  use_empirical <- is.character(strata_weights) && length(strata_weights)==1L &&
    strata_weights=="empirical"
  
  if (use_empirical){
    # fraction of cases falling into each stratum:
    strat_ind <- match_cause(apply(unique_Eti_level,1,paste,collapse=""),
                             apply(stats::model.matrix(Eti_formula,case_data),1,paste,collapse=""))
    user_weight <- vapply(seq_len(n_unique_Eti_level),
                          function(s) mean(strat_ind==s),numeric(1))
  } else if (!is.null(strata_weights)){
    if (is.numeric(strata_weights) && length(strata_weights)==n_unique_Eti_level){
      user_weight <- strata_weights
    } else {
      warning("==[baker] 'strata_weights' must be \"empirical\" or a numeric vector of length ",
              n_unique_Eti_level," (one weight per stratum); using equal weights instead.==",
              call. = FALSE)
    }
  }
  
  # marginalized posterior etiology over all strata using the weights above
  # (causes x iterations):
  Eti_overall_usr_weight <- apply(Eti_prob_scale,1,function(S) t(S)%*%matrix(user_weight,ncol=1))
  
  # the weighted overall truth does not depend on the cause, so compute it once:
  truth_overall <- NULL
  if (!is.null(truth$allEti)){
    truth_overall <- c(matrix(user_weight,nrow=1)%*%do.call(rbind,truth$allEti))
  }
  
  # plot posterior distribution for etiology probability; one column of panels
  # per cause, one row per stratum plus a final row for the overall fraction:
  par(mfcol=c(1+n_unique_Eti_level,Jcause),mar=c(3,8,1,0))
  for (j in seq_len(Jcause)){
    for (site in seq_len(n_unique_Eti_level)){
      hist(Eti_prob_scale[,site,j],xlim=c(0,1),breaks="Scott",freq=FALSE,main="",xlab="",
           ylim=c(0,20))
      if (!is.null(truth$allEti)){
        true_val <- truth$allEti[[site]][[j]]
        abline(v = true_val, col="blue", lwd=3, lty=2) # mark the truth.
        q_interval  <- quantile(Eti_prob_scale[,site,j],c(0.025,0.975))
        is_included <- true_val < q_interval[2] && true_val > q_interval[1]
        abline(v = q_interval,col=c("gray","red")[2-is_included],lwd=2,lty=1)
      }
      
      if (site==1){mtext(text = likelihood$cause_list[j],3,-1.2,cex=2,adj = 0.9)}
      if (j==1){
        # NOTE(review): the stratum label is read from a column named SITE;
        # presumably the strata correspond to its levels -- confirm for other
        # discrete covariates.
        Lines <- list(bquote(paste("level ",.(levels(as.factor(data_nplcm$X$SITE))[site]))),
                      bquote(paste("",.(round(user_weight[site],4)))))
        mtext(do.call(expression, Lines),side=2,line=c(5.5,3.75),cex=1.5,col="blue")
      }
    }
    
    # overall (weighted) etiology fraction for cause j:
    hist(Eti_overall_usr_weight[j,],xlim=c(0,1),breaks="Scott",freq=FALSE,main="",
         xlab="Etiology",ylim=c(0,20))
    if (j==1){
      if (use_empirical){
        Lines <- list(bquote(paste("overall pie (", pi[l],"*)")),
                      "empirical weights")
        mtext(do.call(expression, Lines),side=2,line=c(5.5,3.5),cex=1.5)
      } else{
        mtext("overall pie: ", 2,5, cex=1)
      }
    }
    if (!is.null(truth_overall)){
      abline(v=truth_overall[j],col="blue",lty=1,lwd=2) # mark the truth.
      q_interval_Eti_overall  <- quantile(Eti_overall_usr_weight[j,],c(0.025,0.975))
      is_included_Eti_overall <- truth_overall[j] < q_interval_Eti_overall[2] &&
        truth_overall[j] > q_interval_Eti_overall[1]
      abline(v = q_interval_Eti_overall,col=c("gray","red")[2-is_included_Eti_overall],lwd=2,lty=1)
    }
  }
}

# truth0     <-list(allEti= as.data.frame(t(compute_pEti(data.frame(siteID=c(1,2)),betaEti0))))
# plot_etiology_strat_nested(DIR_NPLCM,strata_weights = "empirical",
#                            truth=truth0,RES_NPLCM = RES_NPLCM_curr)

#' visualize the PERCH etiology regression with a continuous covariate
#' 
#' This function is specifically designed for PERCH data. For one slice of
#' bronze-standard (BrS) measurements it draws two rows of panels, one column per cause:
#' 1) the fitted marginal positive rates against standardized enrollment date
#' (controls: cyan band; cases: red band), and 2) the etiology regression curve with
#' its 95% credible band, together with the overall etiologic fraction.
#' (NB: dealing with NoA, multiple-pathogen causes, other continuous covariates?
#' The positive-rate panels index the measurements by the position of the cause in
#' `cause_list`.)
#' 
#' @param DIR_NPLCM File path to the folder containing posterior samples
#' @param stratum_bool a vector of TRUE/FALSE with TRUE indicating the rows of subjects
#' to include; must be supplied (the default cannot be evaluated).
#' @param bugs.dat a list of the objects defined in `jagsdata.txt`, pre-loaded to save time;
#' default is NULL, in which case the file is sourced from `DIR_NPLCM`.
#' @param slice integer; specifies which slice of bronze-standard data to visualize; Default to 1.
#' @param RES_NPLCM pre-read res_nplcm; default to NULL.
#' @param do_plot TRUE for plotting
#' @param do_rug TRUE for adding rug plots of the control observations to the
#' positive-rate panels
#' @param return_metric TRUE for returning the overall mean etiology, 95% quantiles and s.d.
#' @return If `return_metric` is TRUE, the result of 
#' `make_list(Eti_overall_mean,Eti_overall_q,Eti_overall_sd)`. When `do_plot` is TRUE a figure of
#' etiology regression curves and marginal positive rates is drawn as a side effect.
plot_case_study <- function(
  DIR_NPLCM, stratum_bool=stratum_bool, bugs.dat=NULL, slice=1, RES_NPLCM=NULL, do_plot=TRUE, do_rug=FALSE, return_metric=TRUE){
  # The default `stratum_bool=stratum_bool` refers to itself and cannot be
  # evaluated, so fail early with a readable message when it is omitted:
  if (missing(stratum_bool)){
    stop("==[baker] 'stratum_bool' must be supplied: a TRUE/FALSE vector indicating the subjects to include.==")
  }
  # remember layout, margins and outer margins (all modified below) and
  # restore them when the function exits:
  old_par <- graphics::par("mfrow", "mar", "oma")
  on.exit(graphics::par(old_par), add = TRUE)
  if (!is_jags_folder(DIR_NPLCM)){
    stop("==[baker] Oops, not a folder baker recognizes. Try a folder generated by baker, e.g., a temporary folder?==")
  }
  #
  # Read data and model specification from DIR_NPLCM:
  #
  data_nplcm    <- dget(file.path(DIR_NPLCM,"data_nplcm.txt"))  
  model_options <- dget(file.path(DIR_NPLCM,"model_options.txt"))
  X <- data_nplcm$X
  Y <- data_nplcm$Y
  if (is.null(X$std_date)){
    stop("==[baker] 'std_date' is not a variable in data_nplcm$X; it is needed to order subjects along the date axis.==")
  }
  
  # k_subclass holds one entry per slice; the model is reported as nested if
  # any slice uses more than one subclass:
  k_subclass <- model_options$likelihood$k_subclass
  is_nested  <- any(k_subclass>1)
  
  K_curr <- k_subclass[slice]
  Jcause <- length(model_options$likelihood$cause_list)
  Nd     <- sum(Y==1)
  Nu     <- sum(Y==0)
  
  # objects defined in jagsdata.txt (sourced into a private environment),
  # unless the caller supplied them:
  if(is.null(bugs.dat)){
    new_env <- new.env()
    source(file.path(DIR_NPLCM,"jagsdata.txt"),local=new_env)
    bugs.dat <- as.list(new_env)
    rm(new_env)
  } 
  
  brs_obs <- data_nplcm$Mobs$MBS[[slice]] # observed BrS data for this slice
  # number of BrS measurements in this slice (previously hard-coded as 7);
  # assumes the ThetaBS/PsiBS nodes have one row per column of the slice -- TODO confirm
  JBrS        <- ncol(brs_obs)
  ncol_dm_Eti <- ncol(bugs.dat$Z_Eti)
  templateBS  <- bugs.dat[[paste0("templateBS_",slice)]]
  
  if (do_plot){
    cat("==[baker] plotting etiology regression with >>",c("nested", "non-nested")[2-is_nested],"<< model for BrS Measure slice = ",slice,": ",names(data_nplcm$Mobs$MBS)[[slice]]," .==\n")
  }
  
  if (!is.null(RES_NPLCM)){
    res_nplcm <- RES_NPLCM
  } else {
    res_nplcm <- coda::read.coda(file.path(DIR_NPLCM,"CODAchain1.txt"),
                                 file.path(DIR_NPLCM,"CODAindex.txt"),
                                 quiet=TRUE)
  }
  n_samp_kept <- nrow(res_nplcm)
  # columns of the posterior sample matrix whose names match a regular expression:
  get_res <- function(x) res_nplcm[,grep(x,colnames(res_nplcm))]
  
  #
  # structure the posterior samples (last dimension = MCMC iteration):
  #
  betaEti_samp <- array(t(get_res("^betaEti")),c(ncol_dm_Eti,Jcause,n_samp_kept)) #useful in effect estimation.
  ThetaBS_samp <- array(t(get_res(paste0("^ThetaBS_",slice,"\\["))),c(JBrS,K_curr,n_samp_kept))
  PsiBS_samp   <- array(t(get_res(paste0("^PsiBS_",slice,"\\["))),c(JBrS,K_curr,n_samp_kept))
  Eta_samp     <- array(t(get_res(paste0("^Eta_",slice,"\\["))),c(Nd,K_curr,n_samp_kept))
  Lambda_samp  <- array(t(get_res(paste0("^Lambda_",slice,"\\["))),c(Nu,K_curr,n_samp_kept))
  # subclass weights, Eta rows (Nd of them) stacked on top of Lambda rows (Nu);
  # presumably cases precede controls in data_nplcm -- verify against caller
  subwt_samp   <- abind::abind(Eta_samp,Lambda_samp,along=1)
  linpred      <- function(beta,design_matrix){design_matrix%*%beta}
  
  out_Eti_linpred <- array(apply(betaEti_samp,3,linpred,design_matrix=bugs.dat$Z_Eti),
                           c(Nd,Jcause,n_samp_kept)) # can potentially just add pEti to the monitoring.
  
  # etiology probabilities for the first Nd rows; the remaining Nu rows are
  # padded with zeros so the array lines up with subwt_samp:
  pEti_samp <- abind::abind(aperm(apply(out_Eti_linpred,c(1,3),softmax),c(2,1,3)),
                            array(0,c(Nu,Jcause,n_samp_kept)),along=1)
  PR_case_ctrl <- compute_marg_PR_nested_reg_array(ThetaBS_array = ThetaBS_samp,PsiBS_array = PsiBS_samp,
                                                   pEti_mat_array = pEti_samp,subwt_mat_array = subwt_samp,
                                                   case = Y,template = templateBS)
  
  #
  # controls in the stratum, ordered by date: marginal (false) positive rates
  #
  subset_FPR_ctrl <- Y==0 & stratum_bool # <--- specifies who to look at.
  plotid_FPR_ctrl <- which(subset_FPR_ctrl)[order(X$std_date[subset_FPR_ctrl])]
  curr_date_FPR   <- X$std_date[plotid_FPR_ctrl]
  FPR_prob_scale  <- PR_case_ctrl[plotid_FPR_ctrl,,,drop=FALSE]
  
  FPR_mean <- apply(FPR_prob_scale,c(1,2),mean)
  FPR_q    <- apply(FPR_prob_scale,c(1,2),stats::quantile,c(0.025,0.975))
  
  #
  # cases in the stratum, ordered by date: etiology and marginal positive rates
  #
  subset_Eti    <- Y==1 & stratum_bool # <--- specifies who to look at.
  plotid_Eti    <- which(subset_Eti)[order(X$std_date[subset_Eti])]
  curr_date_Eti <- X$std_date[plotid_Eti]
  
  ## compute the probabilities and posterior mean/quantiles
  Eti_prob_scale   <- aperm(pEti_samp,c(2,1,3))[,plotid_Eti,,drop=FALSE] # causes x cases x iterations
  Eti_mean         <- apply(Eti_prob_scale,c(1,2),mean)
  Eti_q            <- apply(Eti_prob_scale,c(1,2),stats::quantile,c(0.025,0.975))
  # overall etiology: average over the selected cases within each iteration
  Eti_overall      <- apply(Eti_prob_scale,c(1,3),mean)
  Eti_overall_mean <- rowMeans(Eti_overall)
  Eti_overall_sd   <- apply(Eti_overall,1,stats::sd)
  Eti_overall_q    <- apply(Eti_overall,1,stats::quantile,c(0.025,0.975))
  
  ## for cases
  PR_case      <- PR_case_ctrl[plotid_Eti,,,drop=FALSE]
  PR_case_mean <- apply(PR_case,c(1,2),mean)
  PR_case_q    <- apply(PR_case,c(1,2),stats::quantile,c(0.025,0.975))
  
  ##################
  # plot results:
  #################
  if (do_plot){
    if (is.null(X$ENRLDATE)){
      stop("==[baker] 'ENRLDATE' is not a variable in data_nplcm$X; it is needed to label the date axis.==")
    }
    #
    # calendar ticks for the x-axis: quarterly dates between the first day of
    # the month after the earliest enrollment and the first day of the month of
    # the latest enrollment, mapped linearly onto the std_date scale.
    #
    enrl_date  <- as.Date(X$ENRLDATE)
    min_d      <- min(enrl_date)
    max_d      <- max(enrl_date)
    min_d_std  <- unique(X$std_date[which(enrl_date==min_d)])
    max_d_std  <- unique(X$std_date[which(enrl_date==max_d)])
    # days_in_month() is given the date itself so that leap years are respected:
    min_plot_d <- min_d+as.integer(lubridate::days_in_month(min_d))-lubridate::day(min_d)+1
    max_plot_d <- max_d-lubridate::day(max_d)+1
    plot_d     <- seq.Date(min_plot_d,max_plot_d,by = "quarter")
    
    unit_x     <- (max_d_std-min_d_std)/as.numeric(max_d-min_d) # std_date units per day
    plot_d_std <- as.numeric(plot_d - min_d)*unit_x+min_d_std
    
    # show the year only on the first tick of each calendar year:
    rle_res    <- rle(lubridate::year(plot_d))
    format_seq <- rep("%b-%d",length(plot_d))
    format_seq[cumsum(c(1,rle_res$lengths[-length(rle_res$lengths)]))] <- "%Y:%b-%d"
    tick_labels <- format(c(plot_d),format_seq)
    
    cause_list <- model_options$likelihood$cause_list
    
    par(mfcol=c(2,Jcause),oma=c(3,0,3,0),mar=c(2,5,0,1))
    for (j in seq_len(Jcause)){ # <--- the marginal dimension of measurements.
      # need to fix this for NoA! <------------------------ FIX!
      #
      # Figure 1 for case and control positive rates:
      #
      y_lab <- c("","positive rate")[(j==1)+1]
      if (cause_list[j] == "other" ||
          is.na(match_cause(colnames(brs_obs),cause_list[j]))) {
        # no BrS measurement for this cause: empty panel with the cause name
        plot(0,0.5,type="n",ylim=c(0,1),
             xaxt="n",xlab="",ylab=y_lab,las=2,bty="n")
        mtext(cause_list[j],side=3,cex=1.5,line=1)
      } else{                                  #<------------------------ FIX!
        # controls: posterior mean and 95% band
        plot(curr_date_FPR,FPR_mean[,j],type="l",ylim=c(0,1),
             xaxt="n",xlab="",ylab=y_lab,las=2,bty="n")
        polygon(c(curr_date_FPR, rev(curr_date_FPR)),
                c(FPR_q[1,,j], rev(FPR_q[2,,j])),
                col = grDevices::rgb(0, 1, 1,0.5),border = NA)
        
        # rug plot of the control observations (positives on top, negatives at
        # the bottom):
        if(do_rug){
          ctrl_obs <- brs_obs[plotid_FPR_ctrl,j]
          rug(curr_date_FPR[which(ctrl_obs==1)],side=3,col="dodgerblue2",line=0)
          rug(curr_date_FPR[which(ctrl_obs==0)],side=1,col="dodgerblue2",line=1)
        }
        
        mtext(names(brs_obs)[j],side = 3,cex=1.5,line=1)
        
        # cases: posterior mean and 95% band
        lines(curr_date_Eti,PR_case_mean[,j])
        polygon(c(curr_date_Eti, rev(curr_date_Eti)),
                c(PR_case_q[1,,j], rev(PR_case_q[2,,j])),
                col =  grDevices::rgb(1, 0, 0,0.5),border = NA)
        
        if (j==1){
          mtext("1)",side=2,at=0.8,line=3, cex=2,las=1)
        }
      }
      #
      # Figure 2 for Etiology Regression:
      #
      plot(curr_date_Eti,Eti_mean[j,],type="l",ylim=c(0,1),xlab="standardized date",
           ylab=c("","etiologic fraction")[(j==1)+1],bty="n",xaxt="n",yaxt="n",las=2)
      
      # overall pie: posterior mean (solid) and 95% interval (dashed)
      abline(h=Eti_overall_mean[j],col="black",lwd=2)
      abline(h=Eti_overall_q[,j],col="black",lty=2,lwd=1.5)
      
      mtext(paste0(round(Eti_overall_mean[j],3)*100,"%"),side=3,line=-2,cex=1.2)
      mtext(paste0(round(Eti_overall_q[1,j],3)*100,"%"),side=3,line=-3,cex=0.8,adj=0.15)
      mtext(paste0(round(Eti_overall_q[2,j],3)*100,"%"),side=3,line=-3,cex=0.8,adj=0.85)
      
      if (j==2){
        mtext("<- Overall Pie ->",side=2,at=(line2user(-2,3)+line2user(-1,3))/2,las=1,cex=0.8,col="blue")
        mtext("<- 95% CrI ->",side=2,at=(line2user(-3,3)+line2user(-2,3))/2,las=1,cex=0.8,col="blue")
      }
      
      if (j==1){mtext("2)",side=2,at=0.85,line=3, cex=2,las=1)}
      
      axis(1, plot_d_std, tick_labels, cex.axis = 0.8,las=2)
      axis(2,at = seq(0,1,by=0.2),labels=seq(0,1,by=0.2),las=2)
      
      # enrollment dates of the selected cases:
      rug(curr_date_Eti,side=1,line=-0.2,cex=1)
      
      if (j==1){
        mtext(text = "case   -->",side=2,at=line2user(-0.2,1),cex=0.8,las=1)
      }
      
      # 95% credible band of the etiology curve:
      polygon(c(curr_date_Eti, rev(curr_date_Eti)),
              c(Eti_q[1,j,], rev(Eti_q[2,j,])),
              col = grDevices::rgb(0.5,0.5,0.5,0.5),border = NA)
    }
  }
  if (return_metric){
    # make_list() receives the variables by name, so these names are kept as is
    return(make_list(Eti_overall_mean,Eti_overall_q,Eti_overall_sd))
  }
}
# oslerinhealth/baker documentation built on May 22, 2021, 12:05 p.m.