R/src_summaries.R

Defines functions summary_dpGLM summary_tidy

Documented in summary_tidy

## * tidy summary

#' Tidy summary
#'
#' This function provides a tidy summary of the MCMC samples from a \code{dpGLM} or an \code{hdpGLM} model
#'
#' @param object a \code{dpGLM} or \code{hdpGLM} object returned by the function \code{hdpGLM}
#' @param ... The additional parameters accepted are:
#' 
#'            true.beta: (see \link{plot.dpGLM})
#'
#'            true.tau: only used for \code{hdpGLM} objects (see \link{summary.hdpGLM})
#'
#' @return For a \code{dpGLM} object, a data frame summarising each parameter in each cluster (\code{k}). For an \code{hdpGLM} object, a list with two elements, \code{beta} and \code{tau}.
#'
#' @details Data points are assigned to clusters according to the highest estimated probability of belonging to that cluster
#' 
#' @export
summary_tidy <- function(object, ...)
{
    ## `...` is forwarded so that the documented optional arguments
    ## (true.beta, true.tau) actually reach the workers
    if (methods::is(object, 'dpGLM')) {
        return(summary_dpGLM(object, ...))
    }
    if (methods::is(object, 'hdpGLM')) {
        return(summary_hdpGLM(object, ...))
    }
    stop("'object' must be of class 'dpGLM' or 'hdpGLM'", call. = FALSE)
}
## Tidy summary of the coefficients of a dpGLM fit (internal helper used by
## summary_tidy()).
##
## object : a dpGLM object
## ...    : optional argument `true.beta`, a data frame with the true values
##          of the coefficients. When supplied, the estimated clusters are
##          matched to the true ones with hdpGLM_match_clusters().
##
## Returns a data frame with columns k, Parameter, term and the summary
## statistics (Mean, Median, SD, HPD.lower, HPD.upper). When `true.beta` is
## given, it also has the columns True.Cluster.match and True.
## Only the clusters kept by dpGLM_get_occupied_clusters() are summarised.
summary_dpGLM <- function(object, ...)
{
    ## Read the optional arguments from the dots directly. list(...) evaluates
    ## them in the caller's environment, whereas re-evaluating the expressions
    ## captured by match.call() would search this function's enclosing
    ## environment instead (and fail for variables local to the caller).
    dots      = list(...)
    true.beta = dots[["true.beta"]]

    x = dpGLM_get_occupied_clusters(object)

    if(!is.null(true.beta)){
        betas = hdpGLM_match_clusters(x, true=true.beta)
    }else{
        HPD.lower <- function(x) {return(coda::HPDinterval( coda::as.mcmc(x) )[,"lower"])}
        HPD.upper <- function(x) {return(coda::HPDinterval( coda::as.mcmc(x) )[,"upper"])}
        ## long format (one row per draw of each parameter in each cluster),
        ## then one row of summary statistics per (cluster, parameter)
        betas = x$samples %>%
            tibble::as_tibble(.)  %>% 
            tidyr::gather(key=Parameter, value=sample, -k) %>% 
            dplyr::group_by(k, Parameter) %>%
            dplyr::summarize_all(.funs=list(Mean="mean", Median="median", SD="sd", HPD.lower="HPD.lower", HPD.upper="HPD.upper")) %>%
            dplyr::ungroup(.)
    }
    ## include the `term` column (covariate names) if not already present
    if (!"term" %in% names(betas)) {
        covariates = attr(x$samples, "terms")
        betas = betas %>% dplyr::left_join(., covariates %>% dplyr::mutate_if(is.factor, as.character), by=c("Parameter")) 
    }
    ## columns to return
    if (is.null(true.beta)) {
        betas = betas %>%
            dplyr::select(k, Parameter, term, dplyr::everything()) 
    }else{
        betas = betas %>%
            dplyr::select(True.Cluster.match, k, Parameter, term, True, Mean, Median, SD, dplyr::contains("HPD"), dplyr::everything()) 
    }
    betas
}
## Tidy summary of an hdpGLM fit (internal helper used by summary_tidy()).
##
## object : an hdpGLM object
## ...    : optional arguments `true.beta` and `true.tau`, data frames with
##          the true values of beta and tau (see ?summary.hdpGLM for the
##          columns they must contain)
##
## Returns list(beta = summary of the coefficients by cluster k and context j,
##              tau  = summary of the context-level coefficients tau)
summary_hdpGLM <- function(object, ...)
{
    ## Read the optional arguments from the dots directly. list(...) evaluates
    ## them in the caller's environment, whereas re-evaluating the expressions
    ## captured by match.call() would search this function's enclosing
    ## environment instead (and fail for variables local to the caller).
    dots      = list(...)
    true.beta = dots[["true.beta"]]
    true.tau  = dots[["true.tau"]]

    ## summarise beta
    ## --------------
    ## restrict to the clusters kept by hdpGLM_get_occupied_clusters()
    x = hdpGLM_get_occupied_clusters(object)
    if(!is.null(true.beta)){
        ## first we need to match the index of the contexts provided by the user and the one used by the algorithm (see details in the function help)
        true.beta = true.beta  %>% dplyr::full_join(.,  x$context.cov, by=c("j"='C')) 
        ## then for each context, we match the clusters based on smaller distance
        betas     = hdpGLM_match_clusters(x, true=true.beta)
    }else{
        HPD.lower <- function(x) {return(coda::HPDinterval( coda::as.mcmc(x) )[,"lower"])}
        HPD.upper <- function(x) {return(coda::HPDinterval( coda::as.mcmc(x) )[,"upper"])}
        betas = x$samples %>%
            tibble::as_tibble(.)  %>% 
            tidyr::gather(key=Parameter, value=sample, -k, -j) %>% 
            dplyr::group_by(k,j, Parameter) %>%
            dplyr::filter(dplyr::n()>1) %>%  # discard cases with a single draw
            dplyr::summarize_all(.funs=list(Mean="mean", Median="median", SD="sd", HPD.lower="HPD.lower", HPD.upper="HPD.upper")) %>%
            dplyr::ungroup(.)
    }
    ## include the `term` column (covariate names) if not already present
    if (!"term" %in% names(betas)) {
        covariates = attr(x$samples, "terms")
        betas = betas %>% dplyr::left_join(., covariates %>% dplyr::mutate_if(is.factor, as.character), by=c("Parameter")) 
    }
    if(is.null(true.beta)){
        betas = betas %>%
            dplyr::select(k, j, Parameter, term, Mean, Median, SD, dplyr::contains("HPD")) 
    }else{
        betas = betas %>%
            dplyr::select(k, j, Parameter, term, True, Mean, Median, SD, dplyr::contains("HPD")) 
    }

    ## summarise tau
    ## -------------
    ## summary() of the tau samples (presumably coda's summary.mcmc): element
    ## [[1]] holds the statistics (Mean, SD, ...), element [[2]] the quantiles
    tau.summ = x$tau %>% summary(.)
    tau.summ[[1]] = tau.summ[[1]] %>% base::data.frame(Parameter=rownames(.), ., row.names=1:nrow(.)) %>% tibble::as_tibble()
    ## NOTE: the tau "HPD" bounds are the 2.5% and 97.5% quantiles, i.e. an
    ## equal-tailed 95% interval, not a highest-density interval
    tau.summ[[2]] = tau.summ[[2]] %>% base::data.frame(Parameter=rownames(.), ., row.names=1:nrow(.)) %>% tibble::as_tibble() %>% dplyr::select(Parameter, X2.5., X97.5.) %>% dplyr::rename(HPD.lower=X2.5., HPD.upper=X97.5.)
    taus =  tau.summ[[1]] %>% 
        dplyr::full_join(., tau.summ[[2]] , by="Parameter") %>% 
        dplyr::select(Parameter, Mean, SD, dplyr::contains("HPD"), -dplyr::contains("Naive"), -dplyr::contains("Time")) %>%
        dplyr::left_join(., attr(x$tau, "terms"), by=c("Parameter")) %>%
        dplyr::mutate(Description = dplyr::case_when(term.tau == '(Intercept)' ~ paste0("Intercept of ", beta) ,
                                                     term.tau != '(Intercept)' ~  paste0("Effect of ", term.tau, " on ", beta)) ) %>%
        dplyr::select(-term.tau, -term.beta) 

    if(!is.null(true.tau)){
        taus = taus %>%
            dplyr::mutate(Parameter=as.character(Parameter)) %>% 
            dplyr::select(-beta)  %>%
            dplyr::full_join(., true.tau  %>% dplyr::mutate(Parameter=as.character(Parameter))  , by=c("Parameter"))  %>%
            dplyr::select(w, beta, Parameter, Description, True, Mean, SD, dplyr::contains("HPD")) %>%
            dplyr::mutate(w = as.integer(w),
                          beta = as.integer(beta))  %>%
            dplyr::arrange(Parameter) 
    }else{
        ## include columns with the indexes of beta and W for each tau,
        ## parsed from parameter names of the form tau[<w>][<beta>]
        taus = taus %>%
            dplyr::arrange(Parameter)  %>%
            dplyr::mutate(beta = stringr::str_extract(string=Parameter, pattern="\\[[0-9]*\\]$") %>% stringr::str_replace_all(string=., pattern="\\[|\\]", replacement="") %>% as.integer,
                          w    = stringr::str_extract(string=Parameter, pattern="tau\\[[0-9]*\\]") %>% stringr::str_replace_all(string=., pattern="tau|\\[|\\]", replacement="") %>% as.integer)  %>%
            dplyr::select(w, beta, Parameter, Description, dplyr::everything())  %>%
            dplyr::arrange(w, beta) 
    }
    if ('tau.idx' %in% names(taus)) taus = taus %>% dplyr::select(-tau.idx) 

    list(beta=betas, tau=taus)
}

#' nclusters
#'
#' This function returns the number of clusters found in the estimation
#'
#' @param object a \code{dpGLM} or \code{hdpGLM} object returned by the function \code{hdpGLM}
#'
#' @return For a \code{dpGLM} object, a single integer with the number of clusters. For an \code{hdpGLM} object, an integer vector with the number of clusters in each context, named \code{"Context <j>"}.
#'
#' @export
nclusters <- function(object)
{
    if (methods::is(object, 'dpGLM')) {
        n.clusters = length(unique(summary_tidy(object)$k))
    } else if (methods::is(object, 'hdpGLM')) {
        res = (
            summary_tidy(object)$beta
            %>% dplyr::group_by(j)
            %>% dplyr::summarise(nclusters=length(unique(k))) 
            %>% dplyr::rename(Context=j) 
            %>% as.data.frame
        )
        n.clusters = res$nclusters
        names(n.clusters) = paste('Context', res$Context)
    } else {
        ## without this branch the function used to return itself (the name
        ## `nclusters` resolved to this very function) for other classes
        stop("'object' must be of class 'dpGLM' or 'hdpGLM'", call. = FALSE)
    }
    cat('\n')
    n.clusters
}

## * Methods
## ** summary

#' Summary for dpGLM class
#'
#' This function prints a summary of the MCMC samples from the dpGLM model
#'
#' @param object a \code{dpGLM} object returned by the function \code{hdpGLM}
#' @param ... The additional parameters accepted are:
#' 
#'            true.beta: (see \link{plot.dpGLM}). When provided, the true values are shown in the column \code{True}.
#'
#' @return The tidy summary of the clusters (a data frame with one row per parameter and cluster), returned invisibly.
#'
#' @details Data points are assigned to clusters according to the highest estimated probability of belonging to that cluster
#' 
#' @export
summary.dpGLM <- function(object, ...)
{
    print(object)
    cat("\nSummary statistics of clusters with data points\n")

    ## `...` (true.beta) is forwarded so the documented argument takes effect
    s = summary_dpGLM(object, ...)
    k = unique(s$k)
    ## seq_along() rather than 1:length(): no iteration when there are no clusters
    for (i in seq_along(k))
    {
        cat("\n")
        cat("--------------------------------")
        cat("\n")
        cat(paste0("Coefficients for cluster ", i, " (cluster label ", k[i], ")"))
        cat("\n")
        cat("\n")
        stmp = (
            s
            %>% dplyr::filter(k==!!k[i])
            %>% dplyr::select(` `="term",
                              dplyr::matches("^True$"),
                              `Post.Mean`="Mean",
                              `Post.Median`="Median",
                              dplyr::contains("HPD"))
            %>% as.data.frame
        )
        print(stmp)
    }
    cat("\n")
    cat("--------------------------------")
    cat("\n")
    invisible(s)
}



#' Summary for hdpGLM class
#'
#' This is a generic summary function that describes the output of the function \link{hdpGLM}
#'
#' @param object an object of the class \code{hdpGLM} generated by the function \link{hdpGLM}
#' @param ... Additional arguments accepted are:
#'
#'            \code{true.beta}: a \code{data.frame} with the true values of the linear coefficients \code{beta} if they are known. The \code{data.frame} must contain a column named \code{j} with the index of the context associated with that particular linear coefficient \code{beta}. It must match the indexes used in the data set for each context. Another column named \code{k} must be provided, indicating the cluster of \code{beta}, and a column named \code{Parameter} with the name of the linear coefficients (\code{beta1}, \code{beta2}, ..., \code{beta_dx}, where \code{dx} is the number of covariates at the individual level, and beta1 is the coefficient of the intercept term). It must contain a column named \code{True} with the true value of the \code{betas}. Finally, the \code{data.frame} must contain columns with the context-level covariates as used in the estimation of the \link{hdpGLM} function (see Details below).
#' 
#'            \code{true.tau}: a \code{data.frame} with four columns. The first must be named \code{w} and it indicates the index of each context-level covariate, starting with 0 for the intercept term. The second column named \code{beta} must contain the indexes of the betas of individual-level covariates, starting with 0 for the intercept term. The third column named \code{Parameter} must be named \code{tau<w><beta>}, where \code{w} and \code{beta} must be the actual values displayed in the columns \code{w} and \code{beta}. Finally, it must have a column named \code{True} with the true value of the parameter.
#'
#' @return The function prints the summary and invisibly returns a list with two data.frames. The first summarizes the posterior distribution of the linear coefficients \code{beta}. The mean, median, and the 95\% HPD interval are provided. The second data.frame contains the summary of the posterior distribution of the parameter \code{tau}. When \code{true.beta} or \code{true.tau} are provided, the true values are included in a column named \code{True}.
#'
#' @details The function hdpGLM returns a list with the samples from the posterior distribution along with other elements. That list contains an element named \code{context.cov} that connects the index "C" created during the estimation and the context-level covariates. So each unique context-level covariate gets an index during the estimation. The algorithm only requires the context-level covariates, but it creates such index C to help the estimation. If true.beta is provided, it must contain indexes for the context as well, which indicates the context of each specific linear coefficient \code{beta}. Such index will probably be different from the one created by the algorithm. Therefore, when the \code{true.beta} is provided, we need to connect the context index C generated by the algorithm and the column j in the true.beta data.frame in order to compare the true values and the estimated value for each context. That is why we need the values of the context-level covariates as well. The summary uses them as key to merge the true and the estimated values for each context. The true and estimated clusters are matched based on the shortest distance between the estimated posterior average and the true value in each context because the labels of the clusters in the estimation can vary, even though the same data points are classified in the same clusters.
#' 
#' @export
summary.hdpGLM <- function(object, ...)
{
    print(object)
    ## `...` (true.beta, true.tau) is forwarded so the documented arguments take effect
    summ = summary_hdpGLM(object, ...)

    ## individual-level coefficients, one table per context
    ## ----------------------------------------------------
    s = summ[['beta']] %>% dplyr::filter(Parameter!='sigma') 
    contexts = unique(s$j)
    cat("\n")
    cat("\nSummary statistics of clusters with data points in each context\n")
    for (j in contexts)
    {
        cat("\n")
        cat("--------------------------------")
        cat("\n")
        cat(paste0("Coefficients and clusters for context ", j))
        cat("\n")
        cat("\n")
        stmp = (
            s
            %>% dplyr::filter(j==!!j) 
            %>% dplyr::select(` `="term",
                              dplyr::matches("^True$"),
                              `Post.Mean`="Mean",
                              `Post.Median`="Median",
                              dplyr::contains("HPD"),
                              `Cluster`="k")
            %>% as.data.frame
        )
        print(stmp)
    }

    ## Context-level effect
    ## --------------------
    cat("\n")
    cat("--------------------------------")
    cat("\n")
    cat("Context-level coefficients:")
    stau = summ[['tau']] %>% dplyr::filter(Parameter!='sigma') 
    tmp = (
        stau
        %>% dplyr::select("Description",
                          dplyr::matches("^True$"),
                          `Post.Mean`="Mean",
                          dplyr::contains("HPD"))
        ## make the first "W" of each description readable
        %>% dplyr::mutate(Description = stringr::str_replace(string=Description,
                                                             pattern="W",
                                                             replacement="context term ")) 
    )
    cat('\n')
    print(as.data.frame(tmp))
    cat("\n")
    cat("--------------------------------")
    cat("\n")
    invisible(summ)
}


## ** plots

#' Default plot for class dpGLM
#'
#' This function generates density plots with the posterior distribution generated by the function \code{\link{hdpGLM}}
#'
#' @param x a dpGLM object with the samples generated by \code{\link{hdpGLM}}
#' @param terms string vector with the name of covariates to plot. If \code{NULL} (default), all covariates are plotted.
#' @param separate boolean, if \code{TRUE} the linear coefficients \code{beta} will be displayed in their separate clusters. 
#' @param hpd boolean, if \code{TRUE} (default) and \code{separate=TRUE}, the 95\% HPDI lines will be displayed.
#' @param true.beta either \code{NULL} (default) or a \code{data.frame} with the true values of the linear coefficients \code{beta} if they are known. The \code{data.frame} must contain a column named \code{k} indicating the cluster of \code{beta}, and a column named \code{Parameter} with the name of the linear coefficients (\code{beta1}, \code{beta2}, ..., \code{beta_dx}, where \code{dx} is the number of covariates at the individual level, and beta1 is the coefficient of the intercept term). It must contain a column named \code{True} with the true value of the \code{betas}. 
#' @param only.occupied.clusters boolean, if \code{TRUE} it shows only the densities of the clusters that actually have data points assigned to it with high probability. Note: currently only the occupied clusters are plotted regardless of this value.
#' @param title  string, the title of the plot
#' @param subtitle  string, the subtitle of the plot
#' @param focus.hpd boolean, if \code{TRUE} only the draws inside the 95\% HPDI of the posterior density of the linear coefficients \code{beta} of each cluster are displayed (most useful with \code{separate=TRUE})
#' @param ncols integer, the number of columns in the plot 
#' @param legend.position one of four options: "top" (default), "bottom", "left", or "right". It indicates the position of the legend
#' @inheritParams graphics::par
#' @inheritParams stats::density
#' @param colour string with color to fill the density plot
#' @param alpha number between 0 and 1 indicating the degree of transparency of the density 
#' @param display.terms boolean, if \code{TRUE} (default), the covariate name is displayed in the plot
#' @param plot.mean boolean, if \code{TRUE} the posterior mean of every cluster is displayed
#' @param legend.label.true.value a string with the value to display in the legend when the \code{true.beta} is used 
#' @param ... ignored 
#'
#' @return a \code{ggplot} object
#'
#' @examples
#' # Note: this example is just for illustration. MCMC iterations are very reduced
#' set.seed(10)
#' n = 20
#' data = tibble::tibble(x1 = rnorm(n, -3),
#'                                    x2 = rnorm(n,  3),
#'                                    z  = sample(1:3, n, replace=TRUE),
#'                                    y  =I(z==1) * (3 + 4*x1 - x2 + rnorm(n)) +
#'                                        I(z==2) * (3 + 2*x1 + x2 + rnorm(n)) +
#'                                        I(z==3) * (3 - 4*x1 - x2 + rnorm(n)) ,
#'                                    ) 
#' 
#' ## estimation
#' mcmc    = list(burn.in=1, n.iter=50)
#' samples = hdpGLM(y ~ x1 + x2,  data=data, mcmc=mcmc, n.display=1)
#' 
#' plot(samples)
#' 
#' 
#' @export
plot.dpGLM    <- function(x, terms=NULL, separate=FALSE, hpd=TRUE,
                          true.beta=NULL, title=NULL, subtitle=NULL, adjust=1,
                          ncols=NULL, only.occupied.clusters=TRUE,
                          focus.hpd=FALSE, legend.position="top", colour='grey',
                          alpha=.4, display.terms=TRUE, plot.mean=TRUE,
                          legend.label.true.value="True", ...)
{
    ## keep all default options
    op.default <- options()
    on.exit(options(op.default), add=TRUE)
    ## keep current working folder on exit
    dir.default <- getwd()
    on.exit(setwd(dir.default), add=TRUE)

    ## Debug/Monitoring message --------------------------
    msg <- paste0('\n','\nGenerating plot...\n',  '\n'); cat(msg)
    ## ---------------------------------------------------

    ## long table: one row per posterior draw of each beta in each cluster,
    ## joined with the cluster summaries (term, Mean, HPD bounds)
    x = dpGLM_get_occupied_clusters(x)
    tab = x$samples %>%
        tibble::as_tibble(.) %>%
        dplyr::select(-dplyr::contains("sigma"))  %>%
        tidyr::gather(key = Parameter, value=values, -k) %>%
        dplyr::left_join(., summary_tidy(x) %>% dplyr::select(k, Parameter, term, Mean, dplyr::contains("HPD")) , by=c("k", "Parameter"))  %>% 
        ## plotmath labels, parsed by label_parsed in the facets
        dplyr::mutate(Parameter = paste0(stringr::str_extract(Parameter, 'beta') , '[', stringr::str_extract(Parameter, '[0-9]+') ,']'),
                      k = paste0("Cluster~", k)) 
    if (!is.null(terms)) {
        tab = tab %>%
            dplyr::filter(term %in% terms)
    }
    if (!is.null(true.beta)) {
        true =  hdpGLM_match_clusters(x, true=true.beta) %>%
            dplyr::mutate(Parameter = paste0(stringr::str_extract(Parameter, 'beta') , '[', stringr::str_extract(Parameter, '[0-9]+') ,']'),
                          Cluster = paste0("Cluster~", k))
        tab = tab %>%
            dplyr::left_join(., true  %>% dplyr::select(-dplyr::contains("HPD"), -dplyr::contains("Mean")) , by=c("Parameter", "k"="Cluster"))
    }
    if (focus.hpd) {
        tab = tab %>%
            dplyr::filter(HPD.lower <= values & values <= HPD.upper) 
    }
    if (display.terms) {
        ## append the covariate name to the facet label, as plotmath text
        tab = tab %>% 
            dplyr::mutate(term = gsub("\\(", "", term),
                          term = gsub(")", "", term),
                          term = gsub(" ", "\\~", term),
                          term = paste0("(",term,")") 
                          )  %>% 
            tidyr::unite(., Parameter, Parameter, term, sep='~')
    }
    g = tab %>%
        ggplot2::ggplot(.) +
        ggplot2::geom_density(ggplot2::aes(x=values, group=Parameter), fill=colour, adjust=adjust, alpha=alpha, colour='white') +
        ggplot2::ylab('Density') +
        ggplot2::theme_bw()+
        ggplot2::scale_x_continuous(expand = c(0, 0)) +
        ggplot2::scale_y_continuous(expand = c(0, 0)) +
        ggplot2::theme(strip.background = ggplot2::element_rect(colour="white", fill="white"),
                       strip.text.x = ggplot2::element_text(size=12, face='bold', hjust=0),
                       strip.text.y = ggplot2::element_text(size=12, face="bold", vjust=0)) +
        ggplot2::theme(legend.position = legend.position) 

    ## reference lines: true values (if given), cluster means (if plot.mean),
    ## and 95% HPD bounds (only when clusters are shown separately and hpd=TRUE)
    show.true = !is.null(true.beta)
    show.mean = plot.mean
    show.hpd  = separate && hpd
    if (show.true) {
        g = g + ggplot2::geom_vline(data=tab, ggplot2::aes(xintercept=True, col="True",  linetype='True'))
    }
    if (show.mean) {
        g = g + ggplot2::geom_vline(data=tab, ggplot2::aes(xintercept=Mean,  linetype='Mean', col="Mean"))
    }
    if (show.hpd) {
        g = g +
            ggplot2::geom_vline(data=tab, ggplot2::aes(xintercept=HPD.lower,  linetype='95% HPD', col="95% HPD")) +
            ggplot2::geom_vline(data=tab, ggplot2::aes(xintercept=HPD.upper,  linetype='95% HPD', col="95% HPD"))
    }
    ## Legend entries. ggplot2 orders the discrete keys alphabetically
    ## ("95% HPD" < "Mean" < "True"), so the unnamed values and labels below
    ## are listed in that same order.
    legend.spec = data.frame(
        show     = c(show.hpd, show.mean, show.true),
        linetype = c("dashed", "solid", "solid"),
        colour   = c("black", "black", "red"),
        label    = c("95% HPD", if (separate) "Mean" else "Cluster Mean", legend.label.true.value),
        stringsAsFactors = FALSE
    )
    legend.spec = legend.spec[legend.spec$show, , drop=FALSE]
    if (nrow(legend.spec) > 0) {
        g = g +
            ggplot2::scale_linetype_manual(values=legend.spec$linetype, name='', labels=legend.spec$label) +
            ggplot2::scale_color_manual   (values=legend.spec$colour,   name='', labels=legend.spec$label)
    }
    ## one panel per (cluster, coefficient) or per coefficient
    if (separate) {
        g = g + ggplot2::facet_wrap(~  k + Parameter , ncol=ncols,  scales='free', labeller=ggplot2::label_parsed) 
    } else {
        g = g + ggplot2::facet_wrap( ~ Parameter, ncol = ncols, scales='free', labeller=ggplot2::label_parsed)
    }
    ## title and subtitle
    if (!is.null(title)) {
        if (!is.null(subtitle)) {
            g = g + ggplot2::ggtitle(title, subtitle)
        }else{
            g = g + ggplot2::ggtitle(title)
        }
    } else if (!is.null(subtitle)) {
        g = g + ggplot2::labs(subtitle=subtitle)
    }
    g
}

#' Plot
#'
#' Generic function to plot the posterior density estimation produced by the function \code{hdpGLM}
#'
#' @param x an object of the class \code{hdpGLM} generated by the function \link{hdpGLM}
#' @param terms string vector with the name of the individual-level covariates to plot. If \code{NULL} (default), all covariates are plotted.
#' @param j.idx integer vector with the index of the contexts to plot. An alternative is to use the context labels with the parameter \code{j.label} instead of the indexes. If \code{NULL} (default) and j.label is also \code{NULL}, the posterior distribution of all contexts are plotted
#' @param j.label string vector with the names of the contexts to plot. An alternative is to use the context indexes with the parameter \code{j.idx} instead of the context labels. If \code{NULL} (default) and j.idx is also \code{NULL}, the posterior distribution of all contexts are plotted. Note: if contexts to plot are selected using \code{j.label}, the parameter \code{context.id} must also be provided.
#' @param title  string, the title of the plot
#' @param subtitle  string, the subtitle of the plot
#' @param true.beta a \code{data.frame} with the true values of the linear coefficients \code{beta} if they are known. The \code{data.frame} must contain a column named \code{j} with the index of the context associated with that particular linear coefficient \code{beta}. It must match the indexes used in the data set for each context. Another column named \code{k} must be provided, indicating the cluster of \code{beta}, and a column named \code{Parameter} with the name of the linear coefficients (\code{beta1}, \code{beta2}, ..., \code{beta_dx}, where \code{dx} is the number of covariates at the individual level, and beta1 is the coefficient of the intercept term). It must contain a column named \code{True} with the true value of the \code{betas}. Finally, the \code{data.frame} must contain columns with the context-level covariates as used in the estimation of the \link{hdpGLM} function (see Details below).
#' @param ncol integer, the number of columns in the plot
#' @param legend.position one of four options: "bottom" (default), "top", "left", or "right". It indicates the position of the legend
#' @param display.terms boolean, if \code{TRUE} (default), the covariate name is displayed in the plot
#' @param context.id string with the name of the column containing the labels identifying the contexts. This variable should have been specified when the estimation was conducted using the function \code{\link{hdpGLM}}.
#' @param ylab string, the label of the y-axis
#' @param xlab string, the label of the x-axis
#' @param rel.height numeric, passed as \code{rel_min_height} to \code{ggridges::geom_density_ridges}
#' @param x.axis.size numeric, the relative size of the label in the x-axis 
#' @param y.axis.size numeric, the relative size of the label in the y-axis 
#' @param title.size numeric, the relative size of the title of the plot
#' @param panel.title.size numeric, the relative size of the titles in the panel of the plot
#' @param legend.size numeric, the relative size of the legend
#' @param fill.col string with the color of the densities
#' @param border.col string with the color of the border of the densities
#' @inheritParams summary.hdpGLM
#' 
#' @export
plot.hdpGLM <- function(x, terms=NULL, j.label=NULL, j.idx=NULL, title=NULL,
                        subtitle=NULL, true.beta=NULL, ncol=NULL,
                        legend.position="bottom", display.terms=TRUE,
                        context.id=NULL, ylab=NULL, xlab=NULL, x.axis.size=1.1,
                        y.axis.size=1.1, title.size=1.2, panel.title.size=1.5,
                        legend.size=1.1, rel.height=0.01, fill.col="#00000044",
                        border.col='white', ...)
{
    ## Contexts to plot
    ## ----------------
    ## All contexts by default; otherwise selected by index (j.idx) and/or by
    ## label (j.label). j.label is evaluated last, so it takes precedence when
    ## both are given.
    if (is.null(j.idx) && is.null(j.label)) {
        j = x$context.cov$C 
    }else{
        if (!is.null(j.idx)) {
            j = j.idx
        }
        if (!is.null(j.label)) {
            if (is.null(context.id)) {
                stop("\n\nThe \'context.id\' must also be provided if \'j.label\' is used.\n\n")
            }
            j = x$context.cov %>%
                dplyr::filter(x$context.cov[,context.id] %>% dplyr::pull(.) %in% j.label)  %>%
                dplyr::select(C)  %>%
                dplyr::pull(.)
        }
    }
    x    = hdpGLM_get_occupied_clusters(x)
    summ = summary_tidy(x)
    ## Posterior draws of the linear coefficients in long format (one row per
    ## draw and coefficient), with the covariate name (term) of each
    ## coefficient and Parameter relabelled for plotmath (e.g. beta[1]).
    ## Shared by both branches below.
    draws = x$samples %>%
        tibble::as_tibble(.)  %>%
        dplyr::select(-sigma)  %>% 
        tidyr::gather(key = Parameter, value=values, -j, -k) %>%
        dplyr::full_join(., summ$beta %>% dplyr::select(term, Parameter) %>% dplyr::filter(Parameter!='sigma'), by=c('Parameter'))  %>% 
        .hdpGLM_plot_label_parameters(.)
    if (!is.null(true.beta)) {
        ## first match the context indexes provided by the user with the ones
        ## used by the algorithm (see details in the function help)
        true.beta = true.beta  %>% dplyr::full_join(.,  x$context.cov, by=c("j"="C"))
        ## then, for each context, match the clusters based on the smallest
        ## distance; jnext is the upper end of the segment drawn at context j
        tab = hdpGLM_match_clusters(x, true=true.beta) %>%
            dplyr::mutate(jnext = j+1) %>%
            .hdpGLM_plot_label_parameters(.)
        ## x-axis range covers all HPD intervals (computed before any term filter)
        xlim = tab %>%
            dplyr::summarize(lower=min(HPD.lower),
                             upper=max(HPD.upper)) %>%
            c %>%
            unlist
        tab2 = draws
        if (!is.null(terms)) {
            tab2 = tab2 %>% dplyr::filter(term %in% terms)
            tab  = tab  %>% dplyr::filter(term %in% terms)
        }
        if (display.terms) {
            tab2 = .hdpGLM_plot_append_terms(tab2)
            tab  = .hdpGLM_plot_append_terms(tab)
        }
        g = tab2 %>%
            dplyr::filter(j %in% !!j)  %>% 
            dplyr::mutate_if(is.factor, as.character) %>% 
            ggplot2::ggplot(.) +
            ggridges::geom_density_ridges(ggplot2::aes(x=values, y=j, group=j), fill=fill.col, colour=border.col, rel_min_height = rel.height) +
            ggplot2::geom_segment(data=tab  %>%
                                      dplyr::filter(j %in% !!j) %>%
                                      dplyr::mutate_if(is.factor, as.character),
                                  ggplot2::aes(x=True, xend=True, y=j, yend=jnext, col='red')) +
            ggplot2::facet_wrap( ~ Parameter, ncol = ncol, scales='free', labeller=ggplot2::label_parsed) +
            ggplot2::ylab('Context Index') +
            ggplot2::theme_bw()+
            ggplot2::theme(strip.background = ggplot2::element_rect(colour="white", fill="white"),
                           strip.text.x = ggplot2::element_text(size=ggplot2::rel(panel.title.size), face='bold', hjust=0),
                           strip.text.y = ggplot2::element_text(size=ggplot2::rel(panel.title.size), face="bold", vjust=0)) +
            ggplot2::scale_colour_manual(values = c("red", "black"), name="", labels=c('True', "MCMC Cluster Mean")) +
            ggplot2::theme(legend.position = legend.position,
                           legend.text     = ggplot2::element_text(size=ggplot2::rel(legend.size))) +
            ggplot2::xlim(xlim)
    }else{
        ## x-axis range covers the HPD intervals of all summarized parameters
        xlim = summ$beta %>% dplyr::summarize(lower=min(HPD.lower), upper=max(HPD.upper)) %>% c %>% unlist
        tab = draws
        if (!is.null(terms)) {
            tab = tab %>% dplyr::filter(term %in% terms)
        }
        if (display.terms) {
            tab = .hdpGLM_plot_append_terms(tab)
        }
        g = tab %>%
            dplyr::filter(j %in% !!j)  %>% 
            dplyr::mutate_if(is.factor, as.character) %>% 
            ggplot2::ggplot(.) +
            ggridges::geom_density_ridges(ggplot2::aes(x=values, y=j, group=j), fill=fill.col, colour=border.col, rel_min_height = rel.height) +
            ggplot2::ylab('Context Index') +
            ggplot2::facet_wrap( ~ Parameter, ncol = ncol, scales='free', labeller=ggplot2::label_parsed) +
            ggplot2::theme_bw()+
            ggplot2::theme(strip.background = ggplot2::element_rect(colour="white", fill="white"),
                           strip.text.x = ggplot2::element_text(size=ggplot2::rel(panel.title.size), face='bold', hjust=0),
                           strip.text.y = ggplot2::element_text(size=ggplot2::rel(panel.title.size), face="bold", vjust=0)) +
            ggplot2::xlim(xlim)
    }
    ## title and subtitle (a subtitle is shown even when no title is given)
    if (!is.null(subtitle)) {
        g = g + ggplot2::ggtitle(title, subtitle)
    }else if (!is.null(title)) {
        g = g + ggplot2::ggtitle(title)
    }
    ## y-axis: one break per plotted context, labelled with context.id if given
    ## ------
    contexts = x$context.cov %>% dplyr::filter(C %in% !!j)  %>% dplyr::mutate_if(is.factor, as.character)
    if (!is.null(context.id)) {
        context.labels = contexts[,context.id] %>% dplyr::pull(.) %>% as.character
        g = g +
            ggplot2::scale_y_discrete(breaks = contexts$C, labels = context.labels, limits=contexts$C) +
            ggplot2::ylab(context.id) 
    }else{
        g = g +
            ggplot2::scale_y_discrete(breaks = contexts$C, labels = contexts$C %>% as.character, limits=contexts$C)
    }
    if (!is.null(ylab)) {
        g = g + ggplot2::ylab(ylab) 
    }
    if (!is.null(xlab)) {
        g = g + ggplot2::xlab(xlab)
    }
    g = g + ggplot2::theme(axis.title.x = ggplot2::element_text(size=ggplot2::rel(x.axis.size), angle=0),
                           axis.title.y = ggplot2::element_text(size=ggplot2::rel(y.axis.size)),
                           plot.title   = ggplot2::element_text(size=ggplot2::rel(title.size)))
    
    return(g)
}

## Relabel Parameter names such as "beta1" as plotmath expressions such as
## "beta[1]" (used by ggplot2::label_parsed in plot.hdpGLM)
.hdpGLM_plot_label_parameters <- function(tab)
{
    tab %>%
        dplyr::mutate(Parameter = paste0(stringr::str_extract(Parameter, 'beta') , '[', stringr::str_extract(Parameter, '[0-9]+') ,']'))
}

## Append the covariate name (term) to the Parameter label for plot.hdpGLM:
## parentheses are removed from the term, spaces are replaced (intended as the
## plotmath "~" spacer), and the term is wrapped in parentheses and joined to
## Parameter with "~" so the facet label can be parsed by label_parsed
.hdpGLM_plot_append_terms <- function(tab)
{
    tab %>% 
        dplyr::mutate(term = gsub("\\(", "", term),
                      term = gsub(")", "", term),
                      term = gsub(" ", "\\~", term),
                      term = paste0("(",term,")") 
                      )  %>% 
        tidyr::unite(., Parameter, Parameter, term, sep='~')
}

## ** predict

#' dpGLM Predicted values
#'
#' Function returns the predicted (fitted) values of the outcome variable using
#' the estimated posterior expectation of the linear covariate betas produced by
#' the \code{hdpGLM} function
#'
#' For each occupied cluster k, the linear predictor is computed with the
#' posterior mean of the coefficients of cluster k. The predictions of the
#' clusters are then summed, each weighted by the column mean of the sampled
#' cluster probabilities (\code{samples_pi}) of that cluster.
#'
#' @param object a \code{dpGLM} object, output of the function \code{hdpGLM}
#' @param new_data data frame with the values of the covariates that are going
#' to be used to generate the predicted/fitted values. If \code{NULL}
#' (default), the data used in the estimation is used. The posterior mean is
#' used to create the predicted values 
#' @param ... currently unused
#' 
#' @return A one-column numeric matrix with the fitted values of the outcome
#' variable (one row per row of the design matrix), which are produced using
#' the estimated posterior expectation of the linear coefficients \code{beta}.
#'
#' @export
predict.dpGLM <- function(object, new_data=NULL, ...)
{
    samples = object
    data = if (is.null(new_data)) samples$data else new_data

    ## design matrix of the covariates in the (first) model formula
    formula = attr(samples, 'formula1')
    X = getRegMatrix_main(formula, data)$X

    ## posterior mean of the cluster probabilities (pk) and of the linear
    ## coefficients (Mean) of the occupied clusters; unmatched rows are dropped
    K = ncol(samples$pik)
    betas = tibble::tibble(k = seq_len(K), pk = samples$samples_pi %>% colMeans) %>% 
        dplyr::full_join(., samples %>% summary_tidy(., only.occupied.clusters=TRUE) , by=c("k"))  %>%
        tidyr::drop_na() 
    covars = colnames(X)
    yhat = matrix(0, nrow=nrow(X), ncol=1)
    for (cluster in unique(betas$k))
    {
        betas.k = betas %>% dplyr::filter(k == !!cluster)
        weight.k = betas.k %>%
            dplyr::distinct(pk)  %>%
            dplyr::pull(pk)
        ## coefficients as a column vector ordered like the columns of X
        beta.k = betas.k %>%
            dplyr::select(term, Mean)   %>%
            tidyr::spread(., key=term, value=Mean) %>%
            dplyr::select(dplyr::all_of(covars))  %>%
            as.matrix %>%
            t
        yhat = yhat + (weight.k * (X %*% beta.k))
    }
    return(yhat)
}


getRegMatrix_main <- function(formula, data)
{
    ## Build the design matrix of `formula` on `data` (used by the predict
    ## methods). Data used for prediction may lack the outcome variable(s) on
    ## the left-hand side of the formula; every missing one is added as a
    ## constant column (value 1) so the model matrix can still be built.
    output_var = formula.tools::lhs.vars(formula)
    ## setdiff() also handles formulas with zero or several LHS variables
    missing_output = setdiff(output_var, names(data))
    if (length(missing_output) > 0) {
        data[,missing_output] = 1
    }
    ## get design matrix
    func.call <- match.call(expand.dots = FALSE)
    X = .getRegMatrix(func.call, data, weights=NULL, formula_number='')
    return(X)
}



#' hdpGLM Predicted values
#'
#' Function returns the predicted (fitted) values of the outcome variable using
#' the estimated posterior expectation of the linear covariate betas produced by
#' the \code{hdpGLM} function
#'
#' For each context j and each occupied cluster k of that context, the linear
#' predictor is computed with the posterior mean of the coefficients of
#' cluster k in context j. The predictions are summed, each weighted by the
#' cluster probability stored in \code{sample_pi_postMean} for cluster k in
#' context j. Rows are assigned to contexts using the context index stored in
#' the fitted object, so the data used for prediction must have one row per
#' observation used in the estimation.
#'
#' @param object an \code{hdpGLM} object, output of the function \code{hdpGLM}
#' @param new_data data frame with the values of the covariates that are going
#' to be used to generate the predicted/fitted values. If \code{NULL}
#' (default), the data used in the estimation is used. The posterior mean is
#' used to create the predicted values 
#' @param ... currently unused
#' 
#' @return A one-column numeric matrix with the fitted values of the outcome
#' variable (one row per row of the design matrix), which are produced using
#' the estimated posterior expectation of the linear coefficients \code{beta}.
#'
#' @export
predict.hdpGLM <- function(object, new_data=NULL, ...)
{
    samples = object
    data = if (is.null(new_data)) samples$data else new_data

    ## design matrix of the covariates in the (first) model formula
    formula = attr(samples, 'formula1')
    X = getRegMatrix_main(formula, data)$X
    ## context membership comes from the fitted object; a length mismatch
    ## would otherwise recycle the logical index and silently mix contexts
    if (length(samples$context.index) != nrow(X)) {
        stop("\n\nThe data used for prediction must have one row per observation used in the estimation.\n\n")
    }

    ## cluster probabilities in long format: one row per cluster k and context j
    K = nrow(samples$sample_pi_postMean)
    J = ncol(samples$sample_pi_postMean)
    tab_pi = samples$sample_pi_postMean %>%
        as.data.frame() %>%
        tibble::as_tibble() %>%
        dplyr::mutate(k=seq_len(K)) %>%
        tidyr::pivot_longer(cols=paste0(seq_len(J)), names_to = 'j',
                            values_to = 'pi') %>%
        dplyr::mutate(j = as.numeric(as.character(j)))

    ## posterior means of the coefficients of the occupied clusters, per context
    betas = tab_pi %>%
        dplyr::full_join(., samples %>%
                                summary_tidy(., only.occupied.clusters=TRUE) %>%
                                .$beta,
                         by=c("k", 'j')) %>%
        tidyr::drop_na()
    
    yhat = matrix(0, nrow=nrow(X), ncol=1)
    covars = colnames(X)
    for (context in unique(betas$j))
    {
        rows.j = samples$context.index == context
        betas.j = betas %>% dplyr::filter(j == !!context)
        for (cluster in unique(betas.j$k))
        {
            betas.jk = betas.j %>% dplyr::filter(k == !!cluster)
            weight.jk = betas.jk %>%
                dplyr::distinct(pi)  %>%
                dplyr::pull(pi)
            ## coefficients as a column vector ordered like the columns of X
            beta.jk = betas.jk %>%
                dplyr::select(term, Mean)   %>%
                tidyr::spread(., key=term, value=Mean) %>%
                dplyr::select(dplyr::all_of(covars))  %>%
                as.matrix %>%
                t
            ## drop=FALSE keeps a matrix when the context has a single row
            yhat[rows.j] = yhat[rows.j] + (weight.jk * (X[rows.j, , drop=FALSE] %*% beta.jk))
        }
    }
    return(yhat)
}

## ** print


#' Print
#'
#' Generic method to print the output of the \code{dpGLM} function
#'
#' Prints the maximum number of clusters activated during the estimation, the
#' number of MCMC iterations and the burn-in period.
#'
#' @param x a \code{dpGLM} object returned by the function \code{hdpGLM}
#' @param ... ignore
#'
#' @return \code{x}, invisibly
#'
#' @export
print.dpGLM <- function(x, ...)
{
    cat(" \n")
    cat("--------------------------------")
    cat(" \n")
    cat("dpGLM model object")
    cat("\n")
    cat(paste0("\nMaximum number of clusters activated during the estimation: ", x$max_active))
    cat(paste0("\nNumber of MCMC iterations: ", x$n.iter))
    cat(paste0("\nburn-in: ", x$burn.in, "\n"))
    cat("--------------------------------")
    cat(" \n")
    ## print methods return their argument invisibly
    invisible(x)
}

#' Print
#'
#' Generic method to print the output of the \code{hdpGLM} function
#'
#' Prints the maximum number of clusters activated during the estimation, the
#' number of MCMC iterations, the burn-in period, the number of contexts, and a
#' summary (average, standard deviation, median, minimum and maximum) of the
#' number of estimated clusters across contexts.
#'
#' @param x a \code{hdpGLM} object returned by the function \code{hdpGLM}
#' @param ... ignore 
#'
#' @return \code{x}, invisibly
#'
#' @export
print.hdpGLM <- function(x, ...)
{
    s = summary_tidy(x)[['beta']] %>% dplyr::filter(Parameter!='sigma')
    n.contexts = length(unique(s$j))
    ## number of distinct clusters (k) with estimates in each context (j)
    nclusters = s %>%
        dplyr::group_by(j) %>%
        dplyr::summarise(nclusters = length(unique(k))) %>%
        dplyr::pull(nclusters)
    cat(" \n")
    cat("--------------------------------")
    cat(" \n")
    cat("hdpGLM Object")
    cat(" \n")
    cat(paste0("\nMaximum number of clusters activated during the estimation: ",
               x$max_active))
    cat(paste0("\nNumber of MCMC iterations: ", x$n.iter))
    cat(paste0("\nBurn-in: ", x$burn.in, "\n"))
    cat("\n")
    cat('Number of contexts : ')
    cat(n.contexts)
    cat("\n")
    cat("\n")
    cat('Number of clusters (summary across contexts): ')
    cat("\n")
    cat("\n")
    tmp = data.frame(Average=mean(nclusters),
                     Std.Dev=stats::sd(nclusters),
                     Median=stats::median(nclusters),
                     Min.=min(nclusters), 
                     Max.=max(nclusters)
                     )
    print(tmp)
    cat("--------------------------------")
    cat(" \n")
    ## print methods return their argument invisibly
    invisible(x)
}

## ** coef


#' Extract dpGLM fitted coefficients
#'
#' This function gives the posterior mean of the coefficients
#'
#' A short note stating that the coefficients are posterior means is printed
#' before the values are returned.
#'
#' @param object a \code{dpGLM} object returned by the function \code{hdpGLM}
#' @param ... ignored
#'
#' @return A named numeric vector with the posterior means (column \code{Mean}
#' of \code{summary_tidy}); names have the form "<term> (Cluster <k>)".
#' 
#' @export
coef.dpGLM <- function(object, ...)
{
    summ = summary_tidy(object)
    coef.names = paste(summ$term, paste0('(Cluster ', summ$k, ')'))
    coefs = summ$Mean
    names(coefs) = coef.names
    cat("\nNote: Coefficients are the posterior mean\n\n")
    coefs
}


#' Extract hdpGLM fitted coefficients
#'
#' This function gives the posterior mean of the coefficients
#'
#' A short note stating that the coefficients are posterior means is printed
#' before the values are returned.
#'
#' @param object a \code{hdpGLM} object returned by the function \code{hdpGLM}
#' @param ... ignored
#'
#' @return A list with two named numeric vectors of posterior means:
#' \code{beta}, the linear coefficients, with names of the form
#' "<term> (Cluster <k>, context <j>)", and \code{tau}, the context-level
#' coefficients, named by their description.
#' 
#' @export
coef.hdpGLM <- function(object, ...)
{
    summ = summary_tidy(object)
    ## linear coefficients: one per term, cluster (k) and context (j)
    beta = summ$beta
    coefs.beta = beta$Mean
    names(coefs.beta) = paste(beta$term, paste0('(Cluster ', beta$k, ', context ', beta$j, ')'))
    ## context-level coefficients
    tau = summ$tau
    coefs.tau = tau$Mean
    names(coefs.tau) = as.character(tau$Description)
    res = list('beta'=coefs.beta, 'tau'=coefs.tau)
    cat("\n")
    cat("Note: Coefficients are the posterior mean")
    cat("\n")
    cat("\n")
    return(res)
}

## ** mcmc

#' mcmc
#'
#' Generic method to return the MCMC information
#'
#' Prints the number of MCMC iterations and the burn-in period of the
#' estimation, followed by a pointer to packages for convergence diagnostics.
#'
#' @param x a \code{dpGLM} object returned by the function
#'          \code{hdpGLM}
#' @param ... ignore 
#'
#' @return \code{NULL}, invisibly
#'
#' @export
mcmc_info.dpGLM = function(x, ...)
{
    cat("\n\n",
        paste0("Number of iterations: ", x$n.iter), "\n",
        paste0("Burn-in period: ", x$burn.in), "\n\n",
        "Note: for diagnostics, use the package coda or MCMCpack.", "\n\n",
        sep = "")
    invisible()
}



#' mcmc
#'
#' Generic method to return the MCMC information
#'
#' Prints the number of MCMC iterations and the burn-in period of the
#' estimation, followed by a pointer to packages for convergence diagnostics.
#'
#' @param x a \code{hdpGLM} object returned by the function
#'          \code{hdpGLM}
#' @param ... ignore 
#'
#' @return \code{NULL}, invisibly
#'
#' @export
mcmc_info.hdpGLM = function(x, ...)
{
    cat("\n\n",
        paste0("Number of iterations: ", x$n.iter), "\n",
        paste0("Burn-in period: ", x$burn.in), "\n\n",
        "Note: for diagnostics, use the package coda or MCMCpack.", "\n\n",
        sep = "")
    invisible()
}


## * data functions


#' Classify data points
#'
#' This function returns a data frame with the data points classified according to the estimation of cluster probabilities generated by the output of the function \code{\link{hdpGLM}}
#'
#' Each row is assigned to the column of \code{samples$pik} with the highest
#' value in that row (the first one in case of ties).
#'
#' @param data a data frame with the data set used to estimate the \code{\link{hdpGLM}} model
#' @param samples the output of \code{\link{hdpGLM}} 
#'
#' @return \code{data} with an additional first column \code{Cluster}
#'
#' @export
classify <- function(data, samples)
{
    most.likely = apply(X = samples$pik, MARGIN = 1, FUN = which.max)
    data.frame(Cluster = most.likely, data)
}

#' Deprecated
#'
#' Deprecated alias of \code{\link{classify}}. It prints a note recommending
#' \code{classify()} and returns its result.
#'
#' @param data a data frame with the data set used to estimate the \code{\link{hdpGLM}} model
#' @param samples the output of \code{\link{hdpGLM}}
#' 
#' @export
hdpGLM_classify <- function(data, samples)
{
    deprecation.note = "\n\nNote: use classify(). hdpGLM_classify() will be removed in future versions.\n\n"
    cat(deprecation.note)
    classify(data, samples)
}



#' Summary dpGLM data
#'
#' This function summarizes the data and the parameters used to generate it
#' with the function \code{hdpGLM_simulateData}.
#'
#' @param object an object of the class dpGLM_data
#' @param ... ignored 
#'
#' @return The function returns a list with the summary of the data produced by the standard summary function and a \code{data.frame} with the true values of beta for each cluster (columns \code{k}, \code{Parameter} and \code{True}).
#'
#'
#' @export
summary.dpGLM_data <- function(object, ...)
{
    ## one row per cluster, one column per coefficient, plus the cluster index k
    beta.matrix = base::do.call(base::rbind, object$parameters$beta)
    beta.matrix = base::cbind(beta.matrix, k = seq_len(nrow(beta.matrix)))
    ## long format: one row per cluster k and coefficient (Parameter)
    beta.long = tidyr::gather(tibble::as_tibble(beta.matrix),
                              key = Parameter, value = True, -k)
    list(data = summary(object$data), beta = beta.long)
}

#' Summary
#'
#' This functions summarizes the data simulated by the function \code{hdpGLM_simulateData} 
#'
#' @param object an object of the class \code{hdpGLM_data}, which is produced by the function \code{hdpGLM_simulateData} 
#' @param ... ignored 
#'
#' @return It returns a list with three elements. The first is a summary of the data, the second a tibble with the linear coefficients \code{beta} and their values used to generate the data, and the third element is also a tibble with the true values of \code{tau} used to generate the \code{betas}.
#'
#' @export
summary.hdpGLM_data <- function(object, ...)
{
    x = object
    ## linear coefficients: one row per context (j), cluster (k) and coefficient.
    ## The context and cluster indexes are parsed from the column names
    ## (presumably of the form "j.<j>.k.<k>", given the patterns below).
    betas = x$parameters$beta %>% 
        base::data.frame(.) %>%
        dplyr::mutate(Parameter=rownames(.)) %>% 
        tidyr::gather(key=j, value=True, -Parameter) %>%
        dplyr::mutate(k = gsub(x=j, 'j.[0-9]*.k.', ''),
                      j = gsub(x=j, '.k.[0-9]*', ''),
                      j = gsub(x=j, 'j.', ''),
                      j = as.numeric(as.character(j)),
                      k = as.numeric(as.character(k))) %>%
        dplyr::select(j,k,Parameter, True) %>%
        dplyr::as_tibble(.)  %>%
        ## relabel "beta<n>" as the plotmath label "beta[<n-1>]"
        dplyr::mutate(Parameter = paste0(
                          stringr::str_replace(string=Parameter, pattern="[0-9]*$", replacement=""),
                          "[",
                          as.numeric(stringr::str_replace(string=Parameter, pattern="beta", replacement=""))-1 ,
                      "]" )
                      )
    ## context-level coefficients: one row per context-level covariate (dw)
    ## and linear coefficient (dx). Only the two label columns are converted
    ## to numbers; converting the numeric True column through gsub('[a-z]')
    ## would turn values printed in scientific notation (e.g. 1e-05) into NA.
    taus = x$parameters$tau %>%
        base::data.frame(dw = rownames(.), .) %>%
        tidyr::gather(key=dx, value=True, -dw) %>%
        dplyr::mutate(dw = as.numeric(gsub('[a-z]', '', dw)),
                      dx = as.numeric(gsub('[a-z]', '', dx))) %>%
        dplyr::mutate(Parameter = paste0('tau[', dw,"][",dx,"]") )  %>%
        dplyr::select(dw,dx,Parameter, True)  %>%
        dplyr::rename(w = dw, beta = dx) 
    return(list(data = summary(x$data),beta=betas, tau=taus))
}

#' Plot simulated data
#'
#' Create a plot with the beta sampled from its distribution, as a function of the context-level feature \code{W}. Only works for the hierarchical model (hdpGLM), not the dpGLM
#'
#'
#' @param data the output of the function \code{hdpGLM_simulateData} 
#' @param w.idx integer, the index of the context-level covariate to plot
#' @param ncol integer, the number of columns in the grid of the plot
#'
#' @return a \code{ggplot} object
#'
#' @export
plot_beta_sim <- function(data, w.idx, ncol=NULL)
{
    ## summary.hdpGLM_data() is computed once and reused for taus and betas
    data.summary = summary(data)
    ## context-level covariates plus the constant W0 and the context index j
    W = data$parameters$W %>%
        tibble::as_tibble(.)  %>% 
        dplyr::mutate(W0=1, j=seq_len(nrow(.)))
    ## tau coefficients with plotmath expressions of E(beta) as a function of W
    taus = data.summary$tau %>%
        tibble::as_tibble(.)  %>%
        dplyr::mutate(beta      = paste0('beta[', beta, "]"),
                      w         = paste0('W', w),
                      tau.label = Parameter,
                      w.label   = paste0(stringr::str_extract(w, 'W') , '[', stringr::str_extract(w, '[0-9]+') ,']'))  %>%
        dplyr::group_by(beta) %>%
        dplyr::mutate(beta.exp        = paste0(paste0(tau.label, w.label), collapse="~+~"),
                      beta.exp.values = paste0(paste0(round(True, 2), "~", w.label), collapse="~+~"))  %>%
        dplyr::ungroup(.)  %>%
        dplyr::mutate(beta.exp        = stringr::str_replace(string=beta.exp, pattern="W\\[0\\]", replacement=""),
                      beta.exp.values = stringr::str_replace(string=beta.exp.values, pattern="W\\[0\\]", replacement=""))
    ## taus has one row per (w, beta) but the expression labels are constant
    ## within beta; keep one row per beta so the join below does not replicate
    ## every simulated beta once per context-level covariate
    beta.labels = taus %>%
        dplyr::select(dplyr::contains("beta")) %>%
        dplyr::distinct(.)
    betas = data.summary$beta %>%
        dplyr::full_join(., W , by=c("j"))  %>%
        tidyr::gather(key = w, value=W, dplyr::contains("W") )  %>%
        dplyr::left_join(., beta.labels, by=c('Parameter'="beta")) 
    parameters = betas %>%
        dplyr::full_join(., taus , by=c("w", "Parameter"="beta"), suffix=c(".beta", ".tau")) 
    parameters = parameters %>%
        dplyr::mutate(Parameter = paste0("E(", Parameter, ')'),
                      Parameter.tau = Parameter.tau,
                      order = paste0(stringr::str_extract(w, '[0-9]+')),
                      facet = paste0(Parameter, "==", beta.exp.values.tau)) 
    ## plot
    g = parameters %>% 
        tidyr::separate(., col=w, into=c("w.label", "w.idx"), sep="W", remove=FALSE) %>%
        dplyr::mutate(w.idx = as.numeric(w.idx))  %>%
        dplyr::filter(w.idx == !!w.idx) %>% 
        ggplot2::ggplot(.)+
        ggplot2::geom_point(ggplot2::aes(x=W, y=True.beta, colour=as.factor(k)), size=2, alpha=.5) +
        ggplot2::geom_smooth(ggplot2::aes(x=W, y=True.beta), colour='grey40', size=.5, fill='grey80', method="lm") +
        ggplot2::facet_wrap( ~ facet, ncol = ncol, scales='free', labeller=ggplot2::label_parsed) +
        ggplot2::ylab(bquote("True randomly generated values of "~beta))+
        ggplot2::theme_bw()+
        ggplot2::theme(strip.background = ggplot2::element_rect(colour="white", fill="white"),
                       strip.text.x = ggplot2::element_text(size=12, face='bold', hjust=0),
                       strip.text.y = ggplot2::element_text(size=12, face="bold", vjust=0)) +
        ggplot2::scale_colour_brewer(palette='Dark2', name="Cluster (k)")+
        ggplot2::xlab(bquote(W[.(w.idx)]))+
        ggplot2::theme(legend.position = "top") 
    return(g)
}




#' Print
#'
#' Generic method to print the output of the \code{hdpGLM_simulateData} function
#'
#' Prints the sample size, the number of clusters, the first rows of the data,
#' the first and last observation index of each cluster, the cluster
#' probabilities and the linear coefficients of each cluster.
#'
#' @param x a \code{dpGLM_data} object returned by the function
#'          \code{hdpGLM_simulateData}
#' @param ... ignore 
#'
#' @return \code{x}, invisibly
#'
#' @export
print.dpGLM_data <- function(x, ...)
{
    cat("\n\n")
    cat(paste("Sample size:", nrow(x$data)))
    cat("\n")
    cat(paste("Number of Clusters (Z):", length(unique(x$Z))))
    cat("\n")
    cat("\n")
    cat("Data (head):\n")
    cat("------------\n")
    print(head(x$data, 3))
    cat("...\n")
    for (cluster in unique(x$Z))
    {
        ## first and last observation index assigned to this cluster
        members = which(x$Z==cluster)
        cat("\n")
        cat(paste("Cluster", cluster,
                  ": observation", min(members),
                  "to", max(members)))
    }
    cat("\n\n")
    cat("Parameters:\n")
    cat("----------\n")
    cat("Cluster probabilities (pi): ")
    cat(round(x$parameters$pi, digits=4))
    cat("\n\n")
    cat(paste("Linear coefficients:\n"))
    ## one row per cluster with its coefficients
    coef = x$parameters$beta
    res = dplyr::bind_rows(coef) %>% 
        dplyr::mutate(`Cluster` = seq_along(coef)) %>% 
        dplyr::select("Cluster", dplyr::everything())
    print(as.data.frame(res))
    cat("\n\n")
    ## print methods return their argument invisibly
    invisible(x)
}


#' Print
#'
#' Generic method to print the output of the \code{hdpGLM_simulateData} function
#'
#' Prints the sample size, the number of contexts and clusters, the first rows
#' of the data, the cluster probabilities, the linear coefficients of the first
#' two and of the last context, and the context-level coefficients (tau).
#'
#' @param x a \code{hdpGLM_data} object returned by the function
#'          \code{hdpGLM_simulateData}
#' @param ... ignore 
#'
#' @return \code{x}, invisibly
#'
#' @export
print.hdpGLM_data <- function(x, ...)
{
    cat("\n\n")
    cat(paste("Sample size:", nrow(x$data)))
    cat("\n")
    cat(paste("Number of Contexts (J): ", length(x$parameters$beta)))
    cat("\n")
    cat(paste("Number of Clusters (Z) per context (J): ", length(unique(x$Z))))
    cat("\n\n")
    cat("Data (head):\n")
    cat("------------\n")
    print(head(x$data, 3))
    cat("...\n\n")
    cat("Parameters:\n")
    cat("----------\n")
    cat("Cluster probabilities (pi): ")
    cat(round(x$parameters$pi, digits=4))
    cat("\n\n")
    cat(paste("Linear coefficients (beta):"))
    nJ = length(x$parameters$beta)
    ## show the first two contexts and the last one; the indexes are
    ## de-duplicated and bounded by nJ so that data with fewer than three
    ## contexts neither repeats a context nor indexes past the last one
    shown = unique(c(1, 2, nJ))
    shown = shown[shown >= 1 & shown <= nJ]
    for (i in shown)
    {
        cat("\n")
        cat('Context: ', i)
        cat("\n")
        coef = x$parameters$beta[[i]]
        res = dplyr::bind_rows(coef) %>% 
            dplyr::mutate(`Cluster` = names(coef)) %>% 
            dplyr::select("Cluster", dplyr::everything())
        print(as.data.frame(res))
        ## the ellipsis marks the contexts that are not printed
        if (i==2 && nJ > 3) {
            cat("... \n")
        }
    }
    cat("\n\n")
    cat("Context-level coefficients (tau) for each linear coefficient (beta)")
    cat("\n")
    tau = x$parameters$tau
    row.names(tau) = stringr::str_replace(string=row.names(tau),
                               pattern="w", replacement="tau")
    colnames(tau) = paste0('beta', seq_len(ncol(tau)))
    print(tau)
    cat("\n\n")
    ## print methods return their argument invisibly
    invisible(x)
}

## * ancillary functions

hdpGLM_match_clusters <- function(samples, true)
{
## ' Match Labels of Estimation and True Clusters
## '
## ' This function matches the estimated clusters and the true values of the linear coefficients of the clusters.
## ' Each estimated cluster is matched to the true cluster returned by
## ' hdpGLM_match_clusters_aux (smallest distance between the estimated
## ' posterior means and the true values); for hdpGLM objects the matching is
## ' done separately within each context j.
## '
## ' @param samples a \code{\link{hdpGLM}} or a \code{dpGLM} object, output of the function \code{\link{hdpGLM}}. 
## ' @param true a data frame with the true values of the parameters that generate the data (columns k, Parameter, True and, for hdpGLM, j; see \code{hdpGLM_simulateParameters})
## '
## ' @return The function returns a table with the labels of the clusters from the MCMC estimation provided by the function \code{\link{hdpGLM}} and the labels used to generate the data, which is in the column \code{True.Cluster.match} of the output. The output table also contains the true value of the linear coefficients that were used to generate the data.
## '
## ' @export
    tab = NULL
    if(methods::is(samples, 'dpGLM')){
        true = true  %>% dplyr::mutate(Parameter = as.character(Parameter))
        tab = summary_tidy(samples) %>%
            data.frame(., row.names=1:nrow(.)) %>%
            dplyr::filter(Parameter!="sigma") %>% 
            dplyr::group_by(k) %>%
            dplyr::mutate(Parameter=as.character(Parameter),
                          True.Cluster.match = purrr::pmap_dbl(.l = list(estimate = list(tibble::tibble(Parameter=Parameter, Mean=Mean)),
                                                                         true     = list(true) ),
                                                               .f=function(estimate, true) hdpGLM_match_clusters_aux(estimate, true) ) ) %>%
            dplyr::ungroup(.) %>%
            dplyr::full_join(., true , by=c("True.Cluster.match"= "k", 'Parameter'='Parameter'))  %>%
            dplyr::select(k, True.Cluster.match, Parameter, True, Mean, Median, SD, dplyr::contains("HPD"))  %>%
            dplyr::arrange(True.Cluster.match) 
    }
    if(methods::is(samples, 'hdpGLM')){
        estimates = summary_tidy(samples)$beta 
        ## match the clusters within each context, then stack the contexts
        tab = lapply(unique(estimates$j), function(context) {
            truej = true[true$j==context,]
            estimates[estimates$j==context,] %>%
                dplyr::filter(Parameter!="sigma")  %>%
                dplyr::group_by(k) %>%
                dplyr::mutate(True.Cluster.match = purrr::pmap_dbl(.l = list(estimate = list(tibble::tibble(Parameter=Parameter, Mean=Mean)),
                                                                             true     = list(truej) ),
                                                                   .f=function(estimate, true) hdpGLM_match_clusters_aux(estimate, true) ),
                              Parameter=as.character(Parameter)) %>%
                dplyr::ungroup(.) %>%
                dplyr::left_join(., truej , by=c("True.Cluster.match"= "k", 'Parameter', "j"))
        }) %>%
            dplyr::bind_rows(.) %>%
            dplyr::select(j, k, True.Cluster.match, Parameter, True, dplyr::everything())  %>% 
            dplyr::arrange(j, k, Parameter)
    }
    if (is.null(tab)) {
        stop("\n\n'samples' must be a 'dpGLM' or a 'hdpGLM' object.\n\n")
    }
    return(tab)
}

hdpGLM_match_clusters_aux <- function(estimate, true)
{
    ## Return the label k of the true cluster closest to one estimated cluster.
    ##
    ## estimate: table with columns Parameter and Mean (posterior means of a
    ##           single estimated cluster).
    ## true:     table with columns k, Parameter and True (true coefficients).
    ##
    ## For each true cluster k, the distance edist(True, Mean) is computed over
    ## the rows joined by Parameter, and the k with the smallest distance is
    ## returned. NA distances are ignored and ties are broken by taking the
    ## first k, so the result always has length one (NA_real_ when no distance
    ## could be computed), as required by purrr::pmap_dbl in
    ## hdpGLM_match_clusters.
    distances = true %>%
        dplyr::group_by(k) %>%
        dplyr::full_join(., estimate , by=c("Parameter"))  %>%
        dplyr::mutate(d = edist(True, Mean)) %>%
        dplyr::ungroup(.)
    if (all(is.na(distances$d))) {
        return(NA_real_)
    }
    closest = distances %>%
        dplyr::filter(d == min(d, na.rm=TRUE))  %>% 
        dplyr::pull(k) %>%
        base::unique(.)
    return(closest[1])
}
hdpGLM_get_occupied_clusters <- function(x)
{
    ## Keep only the posterior samples of the linear coefficients of the
    ## clusters (k) that have data points assigned to them in each context (j).
    ## A data point is assigned to the cluster with the largest value in its
    ## row of x$pik; x$context.index gives the context of each data point.
    ## Returns x with x$samples filtered; the 'terms' attribute and the end
    ## iteration stored in 'mcpar' are carried over from the original samples.
    terms  = attr(x$samples, 'terms')
    ## read the end iteration BEFORE x$samples is replaced: coda::as.mcmc()
    ## below builds new 'mcpar' attributes for the filtered matrix
    n.iter = attr(x$samples, 'mcpar')[2]
    active = apply(x$pik, 1, which.max)
    active_in_each_context = tibble::tibble(k=active, j = x$context.index) %>%
        dplyr::distinct(.)
    ## semi_join keeps only sample rows whose (k, j) pair is active, without
    ## adding rows for active pairs that have no samples
    x$samples = x$samples %>%
        tibble::as_tibble(.) %>%  
        dplyr::semi_join(., active_in_each_context, by=c('k', 'j')) %>%
        as.matrix %>% 
        coda::as.mcmc(.)
    attr(x$samples, 'mcpar')[2] = n.iter
    attr(x$samples, 'terms') = terms
    return(x)
}
dpGLM_get_occupied_clusters <- function(x)
{
    ## Keep only the posterior samples of the clusters with data points
    ## assigned to them (each data point goes to the cluster with the largest
    ## value in its row of x$pik). Returns x with x$samples filtered; the
    ## 'terms' attribute and the end iteration stored in 'mcpar' are carried
    ## over from the original samples.
    terms      = attr(x$samples, 'terms')
    n.iter     = attr(x$samples, 'mcpar')[2]
    active     = unique(apply(x$pik, 1, which.max))
    idx_active = x$samples[,'k'] %in% active
    ## drop=FALSE keeps a one-row selection as a matrix with its column names
    ## (otherwise later x$samples[,'k'] lookups would fail)
    x$samples  = coda::as.mcmc(x$samples[idx_active, , drop=FALSE])
    attr(x$samples, 'mcpar')[2] = n.iter
    attr(x$samples, 'terms') = terms
    return(x)
}
dpGLM_select_non_zero <- function(x, select_perc_time_active=60)
{
    ## Restrict the posterior samples of a dpGLM fit to "non-zero" clusters.
    ## A cluster is kept when, among its non-sigma parameters, not all HPD
    ## intervals contain zero, restricted to rows whose
    ## Percentage.of.Iter..Cluster.was.active exceeds select_perc_time_active.
    ##
    ## x: a dpGLM object; x$samples is a matrix with a column 'k' (cluster)
    ##    and an 'mcpar' attribute.
    ## select_perc_time_active: threshold (percentage of iterations).
    ## Returns x with x$samples restricted to the selected clusters and the
    ## end iteration in 'mcpar' reset to its value before subsetting.
    ##
    ## NOTE(review): this function reads the columns Cluster, HPD.l, HPD.u and
    ## Percentage.of.Iter..Cluster.was.active, and takes parameter names from
    ## row.names(). The visible part of summary_dpGLM builds a tibble with
    ## columns such as k, Parameter, Mean, Median, SD, HPD.lower and HPD.upper,
    ## whose row names are only row numbers -- confirm this still matches the
    ## current output of summary_tidy() before relying on it.
    summary_post = summary_tidy(x)
    summary_post = data.frame(parameter=row.names(summary_post),summary_post, row.names=1:nrow(summary_post))  %>% tibble::as_tibble()

    ## labels of the clusters that satisfy both selection criteria
    clusters_active = summary_post %>%
        dplyr::filter(parameter != "sigma") %>% 
        dplyr::select(parameter, Cluster, dplyr::contains("HPD"), Percentage.of.Iter..Cluster.was.active) %>% 
        dplyr::group_by(Cluster) %>% 
        dplyr::filter(!all(HPD.l <= 0 & 0 <= HPD.u), Percentage.of.Iter..Cluster.was.active>select_perc_time_active) %>%
        dplyr::ungroup() %>%
        dplyr::select(Cluster) %>%
        unique %>%
        dplyr::pull()

    ## as.mcmc() rebuilds 'mcpar', so the original end iteration is restored
    ## NOTE(review): x$samples[idx, ] drops to a vector when exactly one row
    ## is selected; drop=FALSE may be needed.
    n.iter = attr(x$samples, 'mcpar')[2]
    x$samples = coda::as.mcmc(x$samples[x$samples[,"k"] %in% clusters_active,])
    attr(x$samples, 'mcpar')[2] = n.iter

    return(x)
}


## * Plot tau and posterior expectation of beta

#' Plot tau
#'
#' Function to plot posterior distribution of tau
#'
#'
#' @param samples an output of the function \code{\link{hdpGLM}}
#' @param X a string vector with the name of the first-level covariates whose associated tau should be displayed
#' @param W a string vector with the name of the context-level covariate(s) whose linear effect will be displayed. If \code{NULL}, the linear effect tau of all context-level covariates are displayed. Note: the context-level covariate must have been included in the estimation of the model.
#' @param true.tau a \code{data.frame} with four columns. The first must be named \code{w} and it indicates the index of each context-level covariate, starting with 0 for the intercept term. The second column named \code{beta} must contain the indexes of the betas of individual-level covariates, starting with 0 for the intercept term. The third column named \code{Parameter} must be named \code{tau<w><beta>}, where \code{w} and \code{beta} must be the actual values displayed in the columns \code{w} and \code{beta}. Finally, it must have a column named \code{True} with the true value of the parameter.
#' @inheritParams summary.hdpGLM
#' @param show.all.taus  boolean, if \code{FALSE} (default) the posterior distribution of taus representing the intercept of the expectation of beta are omitted
#' @param show.all.betas boolean, if \code{FALSE} (default) the taus affecting only the intercept terms of the outcome variable are omitted
#' @param ncol number of columns of the grid. If \code{NULL}, one column is used
#' @param title string, title of the plot
#' @inheritParams plot.hdpGLM 
#'
#' @examples
#'
#' library(magrittr)
#' set.seed(66)
#' 
#' # Note: this example is just for illustration. MCMC iterations are very reduced
#' set.seed(10)
#' n = 20
#' data.context1 = tibble::tibble(x1 = rnorm(n, -3),
#'                                    x2 = rnorm(n,  3),
#'                                    z  = sample(1:3, n, replace=TRUE),
#'                                    y  =I(z==1) * (3 + 4*x1 - x2 + rnorm(n)) +
#'                                        I(z==2) * (3 + 2*x1 + x2 + rnorm(n)) +
#'                                        I(z==3) * (3 - 4*x1 - x2 + rnorm(n)) ,
#'                                    w = 20
#'                                    ) 
#' data.context2 = tibble::tibble(x1 = rnorm(n, -3),
#'                                    x2 = rnorm(n,  3),
#'                                    z  = sample(1:2, n, replace=TRUE),
#'                                    y  =I(z==1) * (1 + 3*x1 - 2*x2 + rnorm(n)) +
#'                                        I(z==2) * (1 - 2*x1 +   x2 + rnorm(n)),
#'                                    w = 10
#'                                    ) 
#' data = data.context1 %>%
#'     dplyr::bind_rows(data.context2)
#' 
#' ## estimation
#' mcmc    = list(burn.in=1, n.iter=50)
#' samples = hdpGLM(y ~ x1 + x2, y ~ w, data=data, mcmc=mcmc, n.display=1)
#' 
#' 
#' plot_tau(samples)
#' plot_tau(samples, ncol=2)
#' plot_tau(samples, X='x1', W='w')
#' plot_tau(samples, show.all.taus=TRUE, show.all.betas=TRUE, ncol=2)
#' 
#' @export
plot_tau <- function(samples, X=NULL, W=NULL, title=NULL, true.tau=NULL, 
                     show.all.taus=FALSE, show.all.betas=FALSE, ncol=NULL, 
                     legend.position='top', x.axis.size=1.1, y.axis.size=1.1, 
                     title.size=1.2, panel.title.size=1.4, legend.size=1, 
                     xlab=NULL)
{
    ## Build a ggplot with one facet per tau showing its posterior density,
    ## with vertical lines at the posterior mean, at the 95% HPD bounds and,
    ## when true.tau is given, at the true value. Returns the ggplot object.

    ## keep all default options
    op.default <- options()
    on.exit(options(op.default), add=TRUE)
    ## keep current working folder on exit
    dir.default <- getwd()
    on.exit(setwd(dir.default), add=TRUE)

    ## term labels of each tau; `beta` is renamed to beta.label, which is
    ## used below for the facet labels and the show.all.betas filter
    terms = attr(samples$tau, "terms")   %>% dplyr::rename(beta.label = beta)
    summary.tau = summary_tidy(samples)$tau
    if (!is.null(true.tau)) {
        summary.tau = summary.tau %>%
            dplyr::left_join(., true.tau, by=c("Parameter", "w", "beta"))  
    }

    ## one row per posterior draw of each tau, with its summary and labels
    tab = samples$tau %>% 
        tibble::as_tibble()  %>% 
        tidyr::gather(key = Parameter, value=value)  %>% 
        dplyr::mutate(Parameter = as.character(Parameter)) %>%  
        dplyr::left_join(., summary.tau %>% dplyr::mutate(Parameter = as.character(Parameter)), by=c("Parameter"))  %>%
        dplyr::left_join(., terms %>% dplyr::mutate_if(is.factor, as.character), by=c("Parameter"))  %>%
        dplyr::mutate(term.beta = stringr::str_replace(string=term.beta, pattern="\\(", replacement=""),
                      term.beta = stringr::str_replace(string=term.beta, pattern="\\)", replacement=""),
                      term.tau  = stringr::str_replace(string=term.tau, pattern="\\(", replacement=""),
                      term.tau  = stringr::str_replace(string=term.tau, pattern="\\)", replacement=""),
                      ## facet labels are plotmath expressions (label_parsed)
                      facet     = dplyr::case_when(term.tau == 'Intercept' & term.beta == "Intercept" ~ paste0(Parameter, "~(Intercept~of~expectation~of~", beta.label,")"),
                                                   term.tau == 'Intercept' & term.beta != "Intercept" ~ paste0(Parameter, "~(Intercept~of~expectation~of~", beta.label,")"),
                                                   term.tau != 'Intercept'  ~ paste0("atop(",Parameter, "(~effect~of~",
                                                                                     stringr::str_replace_all(string=term.tau, pattern=" ", replacement="~")  , "~on~the~expectation~of~the~effect~of~",
                                                                                     stringr::str_replace_all(string=term.beta, pattern=" ", replacement="~") ,"(", beta.label,")))")
                                                   )
                      )
    
    ## select parameters to display in the plot
    if (!show.all.taus) tab = tab %>% dplyr::filter(!stringr::str_detect(Parameter, pattern="^tau\\[0")) 
    if (!show.all.betas) tab = tab %>% dplyr::filter(beta.label != "beta[0]") 
    ## select taus for beta_i
    if (!is.null(X)) {
        tab = tab %>% 
            dplyr::filter(term.beta %in% X)
    }
    if (!is.null(W)) {
        tab = tab %>%
            dplyr::filter(stringr::str_detect(Description, pattern=W %>% paste0(., collapse="|"))) 
    }
    if (is.null(ncol)) {
        ncol = 1
    }
        
    g = tab %>%
        ggplot2::ggplot(.) +
        ggplot2::geom_density(ggplot2::aes(x=value), fill="#00000044", adjust=1, alpha=.3, colour='white')  +
        ggplot2::geom_vline(ggplot2::aes(xintercept=Mean,  linetype="Mean", col='Mean'))+
        ggplot2::geom_vline(ggplot2::aes(xintercept=HPD.lower,  linetype="95% HPD", col='95% HPD'))+
        ggplot2::geom_vline(ggplot2::aes(xintercept=HPD.upper,  linetype="95% HPD", col='95% HPD'))+
        ggplot2::theme_bw() +
        ggplot2::theme(strip.background = ggplot2::element_rect(colour="white", fill="white"),
                       strip.text.x = ggplot2::element_text(size=ggplot2::rel(panel.title.size), face='bold', hjust=0),
                       strip.text.y = ggplot2::element_text(size=ggplot2::rel(panel.title.size), face="bold", vjust=0))  +
        ggplot2::theme(legend.position = legend.position) +
        ggplot2::scale_x_continuous(expand = c(0.0001, 0.0001)) +
        ggplot2::scale_y_continuous(expand = c(0, 0)) +
        ggplot2::facet_wrap( ~ facet , ncol=ncol, scales='free',labeller = ggplot2::label_parsed)  

    ## manual scales use named values (one entry per legend key) so that the
    ## linetype and colour legends map by name and merge into a single legend
    if (!is.null(true.tau)) {
        g = g +
            ggplot2::geom_vline(ggplot2::aes(xintercept=True,  linetype="True", col='True'))+
            ggplot2::scale_linetype_manual(values = c("95% HPD"="dashed", "Mean"="solid", "True"="solid"), name=' ') +
            ggplot2::scale_colour_manual(values = c("95% HPD"="black", "Mean"="black", "True"="red"), name=' ') 
    }else{
        g = g +
            ggplot2::scale_linetype_manual(values = c("95% HPD"="dashed", "Mean"="solid"), name=' ') +
            ggplot2::scale_colour_manual(values = c("95% HPD"="black", "Mean"="black"), name=' ') 
    }
    if (is.null(title)) {
        g = g +
            ggplot2::ggtitle(label="Posterior distribution of context effect", subtitle="")
    }else{
        g = g +
            ggplot2::ggtitle(label=title, subtitle="")
    }
    if (!is.null(xlab)) {
        g = g +
            ggplot2::xlab(xlab) 
    }
    g = g +
        ggplot2::theme(axis.title.x = ggplot2::element_text(size=ggplot2::rel(x.axis.size), angle=0),
                       axis.title.y = ggplot2::element_text(size=ggplot2::rel(y.axis.size)),
                       plot.title   = ggplot2::element_text(size=ggplot2::rel(title.size)),
                       legend.text  = ggplot2::element_text(size=ggplot2::rel(legend.size))
                       )
    return(g)

}



#' Plot beta posterior distribution
#'
#' Plot the posterior distribution of the linear parameters beta for each context
#'
#'
#' @inheritParams plot_tau
#' @inheritParams plot.hdpGLM
#' @inheritParams plot.dpGLM
#' @param plot.grid boolean, if \code{TRUE} a grid is displayed in the background
#' @param showKhat boolean, if \code{TRUE} a message with the number of estimated clusters by context is displayed
#' @param col.border string, color of the border of the densities
#' @param col string, color of the densities
#' @param xlab.size numeric, size of the breaks in the x-axis 
#' @param ylab.size numeric, size of the breaks in the y-axis
#' @param title.size numeric, size of the title
#' @param legend.size numeric, size of the legend
#' @param xtick.distance numeric, distance between x-axis marks and bottom of the figure
#' @param ytick.distance numeric, distance between y-axis marks and bottom of the figure
#' @param left.margin numeric, distance between left margin and left side of the figure
#'
#' @export
plot_beta <- function(samples, X=NULL, context.id=NULL, true.beta=NULL,
                      title=NULL, subtitle=NULL, plot.mean=FALSE,
                      plot.grid=FALSE, showKhat=FALSE, col=NULL, xlab.size=NULL, 
                      ylab.size=NULL, title.size=NULL, legend.size=NULL, 
                      xtick.distance=NULL, left.margin=0, ytick.distance=NULL, 
                      col.border='white')
{
    ## Draw (base graphics) a header row with title and legend, one row per
    ## context with the posterior density of the coefficient of the
    ## first-level covariate X, and a footer row with the x-axis.
    ## Returns NULL invisibly. `subtitle` is accepted but not used.
    msg.X = "\n\nX must be the name of one first-level covariate used to estimate the model.\n\n"
    if (is.null(X) || length(X) != 1) {
        stop(msg.X)
    }
    ## other options
    par.default <- graphics::par(no.readonly = TRUE)
    on.exit(graphics::par(par.default), add=TRUE)
    ## keep all default options
    op.default <- options()
    on.exit(options(op.default), add=TRUE)
    ## keep current working folder on exit
    dir.default <- getwd()
    on.exit(setwd(dir.default), add=TRUE)
    ## get context indexes and context labels to plot
    ## ----------------------------------------------
    C = samples$context.cov$C %>% unique %>% sort
    if (!is.null(context.id)) {
        j.label = samples$context.cov[,context.id] %>% dplyr::pull(.) %>% as.character
    }else{
        j.label = paste0(" ", C) 
    }
    ## get summary of the estimation
    ## -----------------------------
    ## NOTE(review): summary_tidy() only passes `object` on to summary_hdpGLM();
    ## confirm that true.beta reaches it (the `True` column is read below).
    if (!is.null(true.beta)) {
        summ   = summary_tidy(samples, true.beta=true.beta)
    }else{
        summ   = summary_tidy(samples)
    }
    ## name of the coefficient (e.g. beta1) associated with the covariate X
    beta   = summ$beta %>% dplyr::filter(term == !!X) %>% dplyr::select(Parameter)  %>% unique %>% dplyr::pull(.)
    if (length(beta) == 0) {
        stop(msg.X)
    }
    ## common x-axis range for every context
    xlim   = summ$beta %>% dplyr::summarise(lower = min(HPD.lower), upper=max(HPD.upper)) 
    xlim   = c(xlim$lower, xlim$upper)
    ## default title is a plotmath expression such as beta1(x1); a
    ## user-supplied title is displayed as given
    if (is.null(title)) {
        title = parse(text=paste0(beta,'(',X,')'))
    }
    ## plot layout: header row, one row per context, footer row with the axis
    ## -----------
    lspace = max(max(nchar(j.label)) -.4*max(nchar(j.label)), 2) + left.margin
    graphics::par(las=2,cex.axis=1.2, bty='n', pch=20, cex.main=.9, mar=c(0,lspace,0,0), mgp = c(2,.6,0))
    mat = matrix(c(C,length(C)+1:2), byrow=TRUE, nrow=length(C)+2)
    graphics::layout(mat, heights=c(2,rep(1, nrow(mat)-2), 2))
    ## Plot aesthetics
    ## ---------------
    if (is.null(col)) col="grey"
    if (is.null(xlab.size)) xlab.size=.8
    if (is.null(ylab.size)) ylab.size=.8
    if (is.null(title.size)) title.size=1
    if (is.null(legend.size)) legend.size=1
    if (is.null(xtick.distance)) xtick.distance=0
    if (is.null(ytick.distance)) ytick.distance=0
    ## plot title and legend
    ## ---------------------
    graphics::plot(xlim[1], xlim[2], xlim=xlim, bty='n', xaxt='n', yaxt='n',  axes=FALSE, type="n",   ylab="", xlab="")
    graphics::legend("topleft", legend=title, cex=title.size, bty='n', adj=.5)
    if (plot.mean && !is.null(true.beta)) {
        graphics::legend("center", legend=c("Clusters Posterior Mean", "True"), lwd=c(1,1), col=c("black", "red"), bty="n", horiz=TRUE, cex=legend.size)
    } else if (plot.mean) {
        graphics::legend("center", legend=c("Clusters Posterior Mean"), lwd=c(1), col=c("black"), bty="n", horiz=TRUE, cex=legend.size)
    } else if (!is.null(true.beta)) {
        graphics::legend("center", legend=c("True"), lwd=c(1), col=c("red"), bty="n", horiz=TRUE, cex=legend.size)
    }
    ## plot densities
    ## --------------
    for (j in C)
    {
        ## Debug/Monitoring message --------------------------
        msg <- paste0('\n','Generating plot with posterior density for context ', j,  '...'); cat(msg)
        ## ---------------------------------------------------
        summ.j = summ$beta %>%
            dplyr::filter(j==!!j) %>%
            dplyr::filter(Parameter == !!beta)
        if (!is.null(true.beta)) {
            True = summ.j %>% dplyr::select(True) %>% dplyr::pull(.)
        }
        if (plot.mean || showKhat) {
            Mean = summ.j %>% dplyr::select(Mean) %>% dplyr::pull(.)
            K_estimated = length(Mean)
        }
        ## all_of(): `beta` holds the column name of the coefficient
        g = samples$samples %>%
            tibble::as_tibble()  %>%
            dplyr::filter(j == !!j)  %>%
            dplyr::select(dplyr::all_of(beta))  %>%
            dplyr::pull(.) %>% 
            stats::density(.)
        graphics::plot(g,  col=col,
             xlim=xlim,
             main='',
             xaxt='n',
             yaxt='n',
             ylab='')
        if (plot.grid) {graphics::grid()}
        graphics::mtext(j.label[j], side=2, line=1, cex=ylab.size, adj=ytick.distance) # plot ylabs
        graphics::polygon(g,col=col, border=col.border, lty=.5)
        graphics::axis(side =2, labels = FALSE,   lwd=0.01, lwd.ticks=0, las=1, col='grey') ## vertical line to mark the y-axis
        if (!is.null(true.beta)) graphics::abline(v=True, col='red', lwd=.8)
        if (plot.mean)           graphics::abline(v=Mean, col='black', lwd=2)
        if (showKhat) graphics::legend('topleft', legend=paste0("Estimated Clusters: ", K_estimated) , bty='n')
    }
    graphics::plot(xlim[1], 0, xlim=xlim, bty='n', xaxt='n', yaxt='n',  axes=FALSE, type="n", ylab="", xlab="")
    graphics::axis(1, las=0, pos=1, outer=TRUE, lwd=.01, cex.axis=xlab.size, padj=xtick.distance, tck=-.1, col='grey')
    cat('\n')
    invisible()
}



#' Plot beta posterior expectation
#'
#' This function plots the posterior expectation of beta, the linear effect of the individual-level covariates, as a function of the context-level covariates
#'
#'
#' @inheritParams plot_tau
#' @inheritParams plot.hdpGLM
#' @param smooth.line boolean, if \code{TRUE} the plot will display a regression line representing the regression of the posterior expectation of the linear coefficients betas on the context-level covariates. Default \code{FALSE}
#' @param pred.pexp.beta boolean, if \code{TRUE} the plots will display a line with the predicted posterior expectation of betas obtained using the posterior expectation of taus, the linear coefficients of the expectation of beta
#' @param ncol.beta integer with number of columns of the grid used for each group of context-level covariates
#' @param nrow.w integer with the number of rows of the grid
#' @param ncol.w integer with the number of columns of the grid
#' @param ylab string, the label of the y-axis
#' @param title string, title of the plot
#' @param col.pred.line string with color of fitted line. Only works if \code{pred.pexp.beta=TRUE}
#' @param title.size numeric, absolute size of the title 
#'
#' @examples
#' 
#' library(magrittr)
#' set.seed(66)
#' 
#' # Note: this example is just for illustration. MCMC iterations are very reduced
#' set.seed(10)
#' n = 20
#' data.context1 = tibble::tibble(x1 = rnorm(n, -3),
#'                                    x2 = rnorm(n,  3),
#'                                    z  = sample(1:3, n, replace=TRUE),
#'                                    y  =I(z==1) * (3 + 4*x1 - x2 + rnorm(n)) +
#'                                        I(z==2) * (3 + 2*x1 + x2 + rnorm(n)) +
#'                                        I(z==3) * (3 - 4*x1 - x2 + rnorm(n)) ,
#'                                    w = 20
#'                                    ) 
#' data.context2 = tibble::tibble(x1 = rnorm(n, -3),
#'                                    x2 = rnorm(n,  3),
#'                                    z  = sample(1:2, n, replace=TRUE),
#'                                    y  =I(z==1) * (1 + 3*x1 - 2*x2 + rnorm(n)) +
#'                                        I(z==2) * (1 - 2*x1 +   x2 + rnorm(n)),
#'                                    w = 10
#'                                    ) 
#' data = data.context1 %>%
#'     dplyr::bind_rows(data.context2)
#' 
#' ## estimation
#' mcmc    = list(burn.in=1, n.iter=50)
#' samples = hdpGLM(y ~ x1 + x2, y ~ w, data=data, mcmc=mcmc, n.display=1)
#' 
#' plot_pexp_beta(samples)
#' plot_pexp_beta(samples, X='x1', ncol.w=2, nrow.w=1)
#' plot_pexp_beta(samples, X='x1', ncol.beta=2)
#' plot_pexp_beta(samples, pred.pexp.beta=TRUE, W="w", X=c("x1", "x2"))
#' plot_pexp_beta(samples, W='w', smooth.line=TRUE, pred.pexp.beta=TRUE, ncol.beta=2)
#' 
#' @export
plot_pexp_beta <- function(samples, X=NULL, W=NULL, pred.pexp.beta=FALSE, 
                           ncol.beta=NULL, ylab=NULL, nrow.w=NULL, ncol.w=NULL,
                           smooth.line=FALSE, title=NULL, legend.position='top',
                           col.pred.line='red', x.axis.size=1.1, 
                           y.axis.size=1.1, title.size=12, panel.title.size=1.4,
                           legend.size=1)
{
    ## Build one ggplot per context-level covariate in W, showing the
    ## posterior mean of each beta of the occupied clusters against that
    ## covariate (one facet per beta), and arrange the plots with ggpubr.

    ## Debug/Monitoring message --------------------------
    msg <- paste0('\n','\nGenerating plots ...\n',  '\n'); cat(msg)
    ## ---------------------------------------------------
    ## keep all default options
    op.default <- options()
    on.exit(options(op.default), add=TRUE)
    ## keep current working folder on exit
    dir.default <- getwd()
    on.exit(setwd(dir.default), add=TRUE)

    ## keep only clusters with data points assigned to them
    samples = hdpGLM_get_occupied_clusters(samples)
    Ws = samples$context.cov  %>% dplyr::select(-C) %>% names
    if (is.null(W)) {
        W = Ws
    }else{
        if (!all(W %in% Ws)) {
            stop("\n\nW must contains names of context-level covariates used to estimate the model.\n\n")
        }
    }
    if (is.null(ylab)) {
        ylab = "Clusters Posterior Average"
    }
    ## posterior summary of each beta with a plotmath facet label such as
    ## beta[1]~(x1), joined with the context-level covariates of its context
    betas = summary_tidy(samples)$beta %>%
        dplyr::filter(Parameter != 'sigma')  %>%
        dplyr::mutate(Parameter.label = paste0(stringr::str_extract(Parameter, 'beta') , '[', stringr::str_extract(Parameter, '[0-9]+') ,']'),
                      term.label = stringr::str_replace_all(string=term, pattern="\\)|\\(", replacement=""),
                      term.label = paste0("(", stringr::str_replace_all(string=term.label, pattern=" ", replacement="~") , ")") )  %>%
        tidyr::unite(Parameter.facet, Parameter.label, term.label, sep="~", remove=FALSE) %>%
        dplyr::left_join(., samples$context.cov, by=c('j' = 'C')) 

    if (!is.null(X)) {
        betas = betas %>%
            dplyr::filter(stringr::str_detect(term, pattern=paste0(X, collapse="|") )) 
    }else{
        X = betas$term %>% unique
    }

    if (is.null(ncol.beta)) {
        ncol.beta = betas$Parameter %>% unique %>% length
    }
    if (is.null(nrow.w)) {
        nrow.w = length(W)
    }
    plots = vector("list", length(W))
    for (i in seq_along(W))
    {
        w = W[i]
        ## all_of(): `w` holds a column name; a bare `w` would select a
        ## covariate literally named "w" instead of the one stored in `w`
        plots[[i]] = betas %>%
            dplyr::mutate(k = as.factor(k))  %>% 
            dplyr::select(k, Parameter.facet, Mean, dplyr::all_of(w)) %>%
            ggplot2::ggplot(.) +
            ggplot2::geom_point(ggplot2::aes_string(x=w, y="Mean"), colour="#00000044", size=2) +
            ggplot2::facet_wrap( ~  Parameter.facet , ncol=ncol.beta, scales='free',labeller=ggplot2::label_parsed ) +
            ggplot2::scale_colour_brewer(palette='BrBG', name="Cluster") +
            ggplot2::theme_bw()+
            ggplot2::theme(strip.background = ggplot2::element_rect(colour="white", fill="white"),
                           strip.text.x = ggplot2::element_text(size=ggplot2::rel(panel.title.size), face='bold', hjust=0),
                           strip.text.y = ggplot2::element_text(size=ggplot2::rel(panel.title.size), face="bold", vjust=0)) +
            ggplot2::guides(colour="none") +
            ggplot2::ylab(ylab)

        if (smooth.line) {
            plots[[i]] = plots[[i]] + ggplot2::geom_smooth(ggplot2::aes_string(x=w, y="Mean"), colour='grey40', size=.5, fill='grey80', method="lm") 
        }
        if (pred.pexp.beta) {
            ## line predicted from the posterior expectation of tau
            pred.betas = fit_pexp_beta(samples, W=w)
            pred = pred.betas %>%
                dplyr::left_join(., betas %>% dplyr::select(Parameter, Parameter.facet, term) , by=c("beta"="Parameter")) %>%
                dplyr::filter(term %in% X) 
            plots[[i]] = plots[[i]] +
                ggplot2::geom_line(data= pred %>% dplyr::mutate(linetype="Fitted line using posterior \nexpectation of context effect") ,
                                   ggplot2::aes_string(x=w, y="E.beta.pred", group="beta",  linetype="linetype"), colour=col.pred.line) +
                ggplot2::scale_linetype_manual(values = "solid", name="")             
        }
        plots[[i]] = plots[[i]] +
            ggplot2::theme(axis.title.x = ggplot2::element_text(size=ggplot2::rel(x.axis.size), angle=0),
                           axis.title.y = ggplot2::element_text(size=ggplot2::rel(y.axis.size)),
                           plot.title   = ggplot2::element_text(size=ggplot2::rel(title.size)),
                           legend.text  = ggplot2::element_text(size=ggplot2::rel(legend.size))
                           )
    }
    ## legend.position is honoured whether or not a title is requested
    g = ggpubr::ggarrange(plotlist=plots, nrow=nrow.w, ncol=ncol.w, common.legend=TRUE, legend=legend.position)
    if (!is.null(title)) {
        g = g %>%
            ggpubr::annotate_figure(., top = ggpubr::text_grob(title, color = "black", size =title.size))
    }
    return(g)
}

fit_pexp_beta <- function(samples, W=NULL)
{
    ## Predicted posterior expectation of each beta along a grid of values of
    ## the context-level covariate(s) W, computed as W.new %*% E.tau, where
    ## E.tau holds the posterior means of the taus of that beta.
    ##
    ## samples: an hdpGLM object.
    ## W: names of the context-level covariates to vary. Each one gets 100
    ##    grid points with the other covariates held at their mean (see
    ##    hdpglm_get_new_data). If NULL, every covariate except the
    ##    intercept is used.
    ## Returns a tibble with columns E.beta.pred, the columns of the new data,
    ## beta (the beta being predicted) and cov (the covariate being varied).

    ## Posterior summaries of tau split by the beta they affect. Description
    ## is parsed into the covariate (W) and the beta; presumably it reads like
    ## "Effect of <W> on <beta>" -- TODO confirm. Rows are sorted by Parameter
    ## so the taus line up with the columns of `dat` (assumes every beta lists
    ## its taus in the same covariate order).
    betas.list = summary_tidy(samples)$tau %>%
        dplyr::mutate(Description = stringr::str_replace_all(string=Description, pattern="of |Effect of | on", replacement=""))  %>%
        tidyr::separate(., col=Description, into=c("W", 'beta'), sep=' ') %>%
        base::split(., .$beta) %>%
        purrr::map(.x=., function(.x) .x %>% dplyr::arrange(Parameter) )
    dat = samples$context.cov %>%
        cbind(., Intercept=1)  %>%
        dplyr::select(dplyr::all_of(betas.list[[1]]$W))  %>%
        tibble::as_tibble() 
    if (is.null(W)) {
        W = setdiff(names(dat), "Intercept")
    }
    ## one grid of new data per covariate in W
    new_datas = lapply(W, function(w) hdpglm_get_new_data(dat, 100, w))
    names(new_datas) = W

    ## collect the pieces in a list and bind once (no rbind inside the loop)
    pred = list()
    for (i in seq_along(betas.list))
    {
        E.tau = matrix(betas.list[[i]]$Mean, nrow=nrow(betas.list[[i]]) )
        beta  = names(betas.list)[i]
        for (j in seq_along(new_datas))
        {
            context.cov = names(new_datas)[j]
            W.new       = new_datas[[j]] %>% as.matrix
            E.beta.pred = W.new %*% E.tau
            pred[[length(pred) + 1]] = data.frame(E.beta.pred, W.new, beta = beta, cov=context.cov)
        }
    }
    pred = pred %>%
        dplyr::bind_rows(.) %>%
        tibble::as_tibble() 
    return(pred)
}
hdpglm_get_new_data          <- function(data, n, x, cat.values=NULL)
{
    ## Build new data in which column `x` takes n equally spaced values
    ## between its minimum and maximum in `data`, every other numeric column
    ## is held at its mean, and every other non-numeric column at its first
    ## sorted value. If `cat.values` (a named list of values of categorical
    ## columns) is given, the grid of `x` is repeated for every combination of
    ## those values. Factor columns are returned as character.
    ## work on a tibble so that row subsetting never drops to a vector
    data = tibble::as_tibble(data)
    if (!x %in% names(data)) {
        stop(paste0("\n\nColumn ", x, " not found in the data.\n\n"))
    }
    ## computed once, outside mutate(), so that a column named `x` in the data
    ## cannot mask the argument x
    x.range = range(data[[x]], na.rm=TRUE)
    if(!is.null(cat.values)){
        ## every combination of cat.values, each repeated n times
        cat_vars1 = cat.values %>%
            do.call(expand.grid,.) %>%
            base::replicate(n, ., simplify = FALSE) %>%
            dplyr::bind_rows(.) %>%
            dplyr::arrange_at(names(cat.values)) 
        cat_vars2 = data %>%
            dplyr::select_if(function(col) !is.numeric(col)) %>%
            dplyr::select(-dplyr::one_of(names(cat.values)))  %>%
            dplyr::summarize_all(function(x) sort(unique(x))[1])  %>%
            dplyr::mutate_if(is.factor, as.character) %>%
            base::replicate(nrow(cat_vars1), ., simplify=FALSE)  %>%
            dplyr::bind_rows(.)
        cat_vars = dplyr::bind_cols(cat_vars1, cat_vars2)  %>%
            tibble::as_tibble(.) 

        num_vars = data %>%
            dplyr::select_if(is.numeric)  %>%
            dplyr::mutate_all(mean, na.rm=TRUE) %>%
            dplyr::slice(., 1) %>% 
            .[rep(1,nrow(cat_vars)),] %>%
            dplyr::bind_rows(.)

        newdata = dplyr::bind_cols(num_vars, cat_vars)

        ## within each combination of cat.values, x spans its observed range
        newx = newdata %>%
            dplyr::group_by_at(names(cat.values)) %>%
            dplyr::mutate(.x.grid = seq(x.range[1], x.range[2], length.out=dplyr::n()))  %>%
            dplyr::ungroup(.)  %>% 
            dplyr::pull(.x.grid)

        newdata[[x]] = newx
    }else{
        cat_vars = data %>%
            dplyr::select_if(function(col) !is.numeric(col)) %>%
            dplyr::summarize_all(function(x) sort(unique(x))[1])  %>%
            dplyr::mutate_if(is.factor, as.character) %>%
            .[rep(1,n),]
        num_vars = data %>%
            dplyr::select_if(is.numeric)  %>%
            dplyr::mutate_all(mean, na.rm=TRUE) %>%
            dplyr::slice(., 1) %>% 
            .[rep(1,n),]
        num_vars[[x]] = seq(x.range[1], x.range[2], length.out=n)
        newdata = cat_vars %>% dplyr::bind_cols(., num_vars)
    }

    newdata = newdata %>% dplyr::mutate_if(is.factor, as.character)
    return(newdata)
}

## ** plot_hdpglm 

#' Plot posterior distributions
#'
#' Creates a figure with two panels stacked vertically. The top panel shows
#' the posterior expectation of the betas as a function of the context-level
#' covariates (see \code{\link{plot_pexp_beta}}). The bottom panel shows the
#' posterior distribution of tau (see \code{\link{plot_tau}}).
#'
#'
#' @inheritParams plot_tau
#' @inheritParams plot_pexp_beta
#' @param ncol.taus integer with the number of columns of the grid containing the posterior distribution of tau
#' @param ncol.betas integer with the number of columns of the posterior expectation of betas as function of context-level features
#' @param ncol.w integer with the number of columns to use to display the different context-level covariates
#' @param nrow.w integer with the number of rows to use to display the different context-level covariates
#' @param title.tau string, the title for the posterior distribution of the context effects
#' @param title.beta string, the title for the posterior expectation of beta as function of context-level covariate
#' @param tau.x.axis.size numeric, relative size of the x-axis of the plot with tau
#' @param tau.xlab string, the label of the x-axis for the plot with tau
#' @param tau.y.axis.size numeric, relative size of the y-axis of the plot with tau
#' @param tau.title.size numeric, relative size of the title of the plot with tau
#' @param tau.panel.title.size numeric, relative size of the title of the panels of the plot with tau
#' @param tau.legend.size numeric, relative size of the legend of the plot with tau
#' @param beta.x.axis.size numeric, relative size of the x-axis of the plot with beta
#' @param beta.y.axis.size numeric, relative size of the y-axis of the plot with beta
#' @param beta.title.size numeric, relative size of the title of the plot with beta
#' @param beta.panel.title.size numeric, relative size of the title of the panels of the plot with beta
#' @param beta.legend.size numeric, relative size of the legend of the plot with beta
#'
#' @return The object returned by \code{ggpubr::ggarrange}, with the plot of
#'         the posterior expectation of the betas on top and the plot of the
#'         posterior distribution of tau below it.
#'
#' @seealso \code{\link{plot_tau}}, \code{\link{plot_pexp_beta}}
#'
#' @examples
#' 
#' library(magrittr)
#' # Note: this example is just for illustration. MCMC iterations are very reduced
#' set.seed(10)
#' n = 20
#' data.context1 = tibble::tibble(x1 = rnorm(n, -3),
#'                                    x2 = rnorm(n,  3),
#'                                    z  = sample(1:3, n, replace=TRUE),
#'                                    y  =I(z==1) * (3 + 4*x1 - x2 + rnorm(n)) +
#'                                        I(z==2) * (3 + 2*x1 + x2 + rnorm(n)) +
#'                                        I(z==3) * (3 - 4*x1 - x2 + rnorm(n)) ,
#'                                    w = 20
#'                                    ) 
#' data.context2 = tibble::tibble(x1 = rnorm(n, -3),
#'                                    x2 = rnorm(n,  3),
#'                                    z  = sample(1:2, n, replace=TRUE),
#'                                    y  =I(z==1) * (1 + 3*x1 - 2*x2 + rnorm(n)) +
#'                                        I(z==2) * (1 - 2*x1 +   x2 + rnorm(n)),
#'                                    w = 10
#'                                    ) 
#' data = data.context1 %>%
#'     dplyr::bind_rows(data.context2)
#' 
#' ## estimation
#' mcmc    = list(burn.in=1, n.iter=50)
#' samples = hdpGLM(y ~ x1 + x2, y ~ w, data=data, mcmc=mcmc, n.display=1)
#' 
#' plot_hdpglm(samples)
#' plot_hdpglm(samples, ncol.taus=2, ncol.betas=2, X='x1')
#' plot_hdpglm(samples, ncol.taus=2, ncol.betas=2, X='x1', ncol.w=2, nrow.w=1,
#'             pred.pexp.beta=TRUE,smooth.line=TRUE )
#' 
#'
#' @export
plot_hdpglm <- function(samples, X=NULL, W=NULL, ncol.taus=1, ncol.betas=NULL,
                        ncol.w=NULL, nrow.w=NULL, smooth.line=FALSE,
                        pred.pexp.beta=FALSE, title.tau=NULL, true.tau=NULL,
                        title.beta=NULL, tau.x.axis.size=1.1,
                        tau.y.axis.size=1.1, tau.title.size=1.2,
                        tau.panel.title.size=1.4, tau.legend.size=1,
                        beta.x.axis.size=1.1, beta.y.axis.size=1.1,
                        beta.title.size=1.2, beta.panel.title.size=1.4,
                        beta.legend.size=1, tau.xlab=NULL)
{
    ## Status message goes through message() rather than cat() so callers
    ## can silence it with suppressMessages(); text is unchanged
    msg <- paste0('\n','\nPlot being generated ...\n',  '\n')
    message(msg, appendLF=FALSE)

    ## bottom panel: posterior distribution of the context effects (tau)
    g.tau <- plot_tau(samples, X=X, W=W, ncol=ncol.taus, title=title.tau,
                      true.tau=true.tau, x.axis.size = tau.x.axis.size,
                      y.axis.size = tau.y.axis.size, title.size  = tau.title.size,
                      panel.title.size = tau.panel.title.size,
                      legend.size  = tau.legend.size, xlab=tau.xlab)
    ## top panel: posterior expectation of beta vs context-level covariates
    g.beta <- plot_pexp_beta(samples, X=X, W=W, ncol.beta=ncol.betas, ncol.w=ncol.w,
                             nrow.w=nrow.w, smooth.line=smooth.line,
                             pred.pexp.beta= pred.pexp.beta, title=title.beta,
                             x.axis.size = beta.x.axis.size,
                             y.axis.size = beta.y.axis.size,
                             title.size  = beta.title.size,
                             panel.title.size = beta.panel.title.size,
                             legend.size  = beta.legend.size)

    ggpubr::ggarrange(plotlist=list(g.beta, g.tau), nrow=2)
}

Try the hdpGLM package in your browser

Any scripts or data that you put into this service are public.

hdpGLM documentation built on Oct. 13, 2023, 1:17 a.m.