R/plotting-functions.R

Defines functions plot_simrel plot_covariance plot_beta

Documented in plot_beta plot_covariance plot_simrel

#' @title Plot True Regression Coefficients
#' @name plot_beta
#' @description Draws the true regression coefficients of a simrel object
#'   (\code{BetaCoef}) against \code{Predictor}, with one coloured line per
#'   \code{Response} and a dashed reference line at zero.
#' @param obj A simrel object
#' @param base_theme Base ggplot theme to apply. Either a theme function such
#'   as \code{ggplot2::theme_bw} (it is called with no arguments) or a theme
#'   object such as \code{ggplot2::theme_bw(base_size = 14)}. \code{NULL}
#'   falls back to \code{ggplot2::theme_grey}.
#' @param lab_list List of labs arguments such as x, y, title, subtitle;
#'   passed to \code{ggplot2::labs} via \code{do.call}.
#' @param theme_list List of theme arguments to apply in the plot; passed to
#'   \code{ggplot2::theme} via \code{do.call}.
#' @return A ggplot object of the true regression coefficients for the
#'   simulated data
#' @import ggplot2
#' @examples
#' sobj <- multisimrel()
#' sobj %>%
#'     plot_beta(
#'         base_theme = ggplot2::theme_bw,
#'         lab_list = list(
#'             title = "Regression Coefficients",
#'             subtitle = "From Simulation",
#'             y = "True Regression Coefficients"
#'         ),
#'         theme_list = list(
#'             legend.position = "bottom"
#'         )
#'     )
#' @rdname plot_beta
#' @export
plot_beta <- function(obj, base_theme = theme_grey, lab_list = NULL, theme_list = NULL) {
    # Same NULL fallback as plot_simrel(); a theme function is called, while a
    # ready-made theme object is added to the plot as is.
    if (is.null(base_theme)) base_theme <- theme_grey
    if (is.function(base_theme)) base_theme <- base_theme()

    beta_df <- obj %>% tidy_beta()
    plt <- beta_df %>%
        # As a factor, Response maps to a discrete colour scale
        modify_at("Response", as.factor) %>%
        ggplot(aes_string("Predictor", "BetaCoef",
                          color = "Response",
                          group = "Response")) +
        geom_hline(yintercept = 0,
                   color = "grey2",
                   linetype = 2) +
        geom_line() +
        geom_point() +
        base_theme
    if (!is.null(lab_list)) {
        plt <- plt + do.call(labs, lab_list)
    }
    if (!is.null(theme_list)) {
        plt <- plt + do.call(theme, theme_list)
    }
    plt
}

#' @name plot_covariance
#' @title Plot Covariance between predictor (components) and response (components)
#' @description Draws the values in \code{sigma_df} (\code{Covariance}) against
#'   \code{Predictor}, with one horizontally dodged line per \code{Response}.
#'   When \code{lambda_df} is supplied, its \code{lambda} values are drawn as
#'   light grey bars underneath the lines.
#' @param sigma_df A data.frame generated by tidy_sigma, with columns
#'   \code{Predictor}, \code{Response} and \code{Covariance}
#' @param lambda_df Optional data.frame generated by tidy_lambda, with columns
#'   \code{Predictor} and \code{lambda}; \code{NULL} (the default) omits the bars
#' @param base_theme Base ggplot theme to apply. Either a theme function such
#'   as \code{ggplot2::theme_bw} (it is called with no arguments) or a theme
#'   object such as \code{ggplot2::theme_bw(base_size = 14)}. \code{NULL}
#'   falls back to \code{ggplot2::theme_grey}.
#' @param lab_list List of labs arguments such as x, y, title, subtitle;
#'   passed to \code{ggplot2::labs} via \code{do.call}.
#' @param theme_list List of theme arguments to apply in the plot; passed to
#'   \code{ggplot2::theme} via \code{do.call}.
#' @return A ggplot object of the covariances between predictors (components)
#'   and responses (components)
#' @import ggplot2
#' @examples
#' sobj <- bisimrel(p = 12)
#' sigma_df <- sobj %>%
#'     cov_mat(which = "zy") %>%
#'     tidy_sigma() %>%
#'     abs_sigma()
#' lambda_df <- sobj %>%
#'     tidy_lambda()
#' plot_covariance(
#'     sigma_df,
#'     lambda_df,
#'     base_theme = ggplot2::theme_bw,
#'     lab_list = list(
#'         title = "Covariance between Response and Predictor Components",
#'         subtitle = "The bar represents the eigenvalues predictor covariance",
#'         y = "Absolute covariance",
#'         x = "Predictor Component",
#'         color = "Response Component"
#'     ),
#'     theme_list = list(
#'         legend.position = "bottom"
#'     )
#' )
#' @rdname plot_covariance
#' @export
plot_covariance <- function(sigma_df, lambda_df = NULL, base_theme = theme_grey, lab_list = NULL, theme_list = NULL) {
    # Same NULL fallback as plot_simrel(); a theme function is called, while a
    # ready-made theme object is added to the plot as is.
    if (is.null(base_theme)) base_theme <- theme_grey
    if (is.function(base_theme)) base_theme <- base_theme()

    # Shift each response sideways so coinciding points/lines stay visible
    dodge <- position_dodge(0.5)
    plt <- sigma_df %>%
        # As a factor, Response maps to a discrete colour scale
        modify_at("Response", as.factor) %>%
        ggplot(aes_string(
            x = "Predictor",
            y = "Covariance",
            color = "Response",
            group = "Response"
        ))
    if (!is.null(lambda_df)) {
        # Added before the lines so the bars sit underneath them;
        # inherit.aes = FALSE keeps the Response colour/group mapping off the bars.
        plt <- plt +
            geom_col(
                data = lambda_df,
                aes_string(x = "Predictor", y = "lambda"),
                inherit.aes = FALSE,
                fill = "lightgrey"
            )
    }
    plt <- plt +
        geom_hline(yintercept = 0, color = "grey", linetype = "dashed") +
        geom_line(position = dodge) +
        geom_point(position = dodge) +
        base_theme
    if (!is.null(lab_list)) {
        plt <- plt + do.call(labs, lab_list)
    }
    if (!is.null(theme_list)) {
        plt <- plt + do.call(theme, theme_list)
    }
    plt
}

#' @name plot_simrel
#' @title A wrapper function for a simrel object
#' @description Builds one or more plots of a simrel object: plot 1
#'   (\code{"TrueBeta"}) shows the true regression coefficients, plot 2
#'   (\code{"RelComp"}) the true relevant components, plot 3
#'   (\code{"EstRelComp"}) the estimated relevant components and plot 4
#'   (\code{"EstCov"}) the estimated covariance plot. All plots share the
#'   colour palette and the x-axis range \code{1} to \code{ncomp}.
#' @param obj A simrel object
#' @param ncomp Number of components to show in x-axis
#' @param which An integer vector with values from 1 to 4, or the matching plot
#'   names (\code{"TrueBeta"}, \code{"RelComp"}, \code{"EstRelComp"},
#'   \code{"EstCov"}), specifying which simrel plot(s) to obtain
#' @param layout A layout matrix for arranging the simrel plots, passed to
#'   \code{gridExtra::grid.arrange} as \code{layout_matrix}. When it is
#'   \code{NULL}, or has fewer cells than requested plots, the plots are stacked
#'   in a single column. Ignored when a single plot is requested.
#' @param print.cov Currently unused; kept for backward compatibility.
#' @param use_population A boolean specifying whether plots 3 and 4 use the
#'   population or the sample
#' @param palette Name of a colour palette compatible with RColorBrewer
#' @param base_theme Base ggplot theme function to apply; \code{NULL} falls
#'   back to \code{ggplot2::theme_grey}
#' @param lab_list List of labs arguments such as x, y, title, subtitle. A
#'   flat list is applied to every requested plot. When \code{which} has length
#'   greater than 1 it may instead be a nested list with one list per plot,
#'   matched by plot name when the outer list is named and by position
#'   otherwise; plots without an entry get the default labels.
#' @param theme_list List of theme arguments to apply in the plot. Accepts
#'   the same flat or nested forms as \code{lab_list}; plots without an entry
#'   get the default theme (legend hidden for a single response, at the
#'   bottom otherwise).
#' @return For a single plot, a ggplot object. For several plots, the object
#'   returned by \code{gridExtra::grid.arrange} (which draws the arranged
#'   plots), returned invisibly.
#' @import ggplot2
#' @importFrom gridExtra grid.arrange
#' @importFrom scales pretty_breaks
#' @examples
#' sobj <- bisimrel(p = 12)
#' plot_simrel(sobj, layout = matrix(1:4, 2, 2))
#' plot_simrel(
#'     sobj, which = c(1, 2),
#'     lab_list = list(
#'         TrueBeta = list(title = "Coefficients"),
#'         RelComp = list(title = "Relevant components")
#'     )
#' )
#' @rdname plot_simrel
#' @export
plot_simrel <- function(obj, ncomp = min(obj$p, obj$n, 20), which = c(1L:4L),
                        layout = NULL, print.cov = FALSE, use_population = TRUE,
                        palette = "Set1", base_theme = ggplot2::theme_grey,
                        lab_list = NULL, theme_list = NULL) {
    plot_names <- c("TrueBeta", "RelComp", "EstRelComp", "EstCov")

    # Accept plot names as well as indices and fail early on anything else
    if (is.character(which)) which <- match(which, plot_names)
    if (!is.numeric(which) || length(which) == 0 || anyNA(which) ||
        any(which < 1 | which > length(plot_names))) {
        stop("`which` must contain values between 1 and ", length(plot_names),
             " or plot names (", paste(plot_names, collapse = ", "), ")",
             call. = FALSE)
    }
    n_plots <- length(which)
    if (is.null(base_theme)) base_theme <- ggplot2::theme_grey

    # Labels used for a plot when the caller supplies none for it
    default_labs <- list(
        TrueBeta = list(
            x = "Predictor Variables",
            y = "Regression Coefficients",
            title = "True Regression Coefficients"
        ),
        RelComp = list(
            x = "Components",
            y = "Absolute Covariances",
            title = "True Relevant Components Plot"
        ),
        EstRelComp = list(
            x = "Components",
            y = "Absolute Covariances",
            title = "Estimated Relevant Components Plot"
        ),
        EstCov = list(
            x = "Predictors",
            y = "Absolute Covariances",
            title = "Estimated Covariance Plot"
        )
    )

    # Theme used for a plot when the caller supplies none for it: the legend
    # is hidden when Y has a single column. A function, so obj$Y is only
    # inspected when the default is actually needed.
    default_theme <- function() {
        list(legend.position = if (ncol(obj$Y) == 1) "none" else "bottom")
    }

    # TRUE when `args` is a list of per-plot argument lists. Only unclassed
    # lists count, so classed list values such as ggplot2::element_text()
    # inside a flat theme_list are not mistaken for per-plot entries.
    is_nested <- function(args) {
        n_plots > 1 && length(args) > 0 &&
            all(vapply(args, function(a) is.list(a) && !is.object(a), logical(1)))
    }

    # Arguments for the `pos`-th requested plot, whose name is `nm`
    pick_args <- function(args, pos, nm) {
        if (!is_nested(args)) return(args)
        if (!is.null(names(args))) return(args[[nm]])
        if (pos <= length(args)) args[[pos]] else NULL
    }

    # Build the plot `nm` and apply the shared palette and x-axis settings
    build_plot <- function(nm, labs_i, theme_i) {
        if (is.null(labs_i)) labs_i <- default_labs[[nm]]
        if (is.null(theme_i)) theme_i <- default_theme()
        plt <- switch(
            nm,
            TrueBeta = {
                obj %>%
                    plot_beta(base_theme = base_theme,
                              lab_list = labs_i, theme_list = theme_i)
            },
            RelComp = {
                # Multivariate simrel objects use the "zw" covariance here
                sigma_df <- obj %>%
                    cov_mat(which = if (obj$type == "multivariate") "zw" else "zy") %>%
                    tidy_sigma() %>%
                    abs_sigma()
                lambda_df <- obj %>%
                    tidy_lambda()
                plot_covariance(sigma_df, lambda_df, base_theme = base_theme,
                                lab_list = labs_i, theme_list = theme_i)
            },
            EstRelComp = {
                sigma_df <- obj %>%
                    cov_mat(which = "zy", use_population = use_population) %>%
                    tidy_sigma() %>%
                    abs_sigma()
                lambda_df <- obj %>%
                    tidy_lambda(use_population = use_population)
                plot_covariance(sigma_df, lambda_df, base_theme = base_theme,
                                lab_list = labs_i, theme_list = theme_i)
            },
            EstCov = {
                sigma_df <- obj %>%
                    cov_mat(which = "xy", use_population = use_population) %>%
                    tidy_sigma() %>%
                    abs_sigma()
                plot_covariance(sigma_df, base_theme = base_theme,
                                lab_list = labs_i, theme_list = theme_i)
            }
        )
        plt +
            ggplot2::scale_color_brewer(palette = palette) +
            ggplot2::coord_cartesian(xlim = c(1, ncomp)) +
            ggplot2::scale_x_continuous(breaks = scales::pretty_breaks(ncomp))
    }

    plts <- lapply(seq_len(n_plots), function(k) {
        nm <- plot_names[which[k]]
        build_plot(nm, pick_args(lab_list, k, nm), pick_args(theme_list, k, nm))
    })
    if (n_plots == 1) return(plts[[1]])

    names(plts) <- plot_names[which]
    if (is.null(layout) && n_plots == 2) layout <- matrix(c(1, 2), 2)
    # Too few layout cells (or no layout): stack all plots in one column
    if (n_plots > length(layout)) layout <- matrix(seq_len(n_plots), n_plots)
    plts$layout_matrix <- layout
    invisible(do.call(gridExtra::grid.arrange, plts))
}
simulatr/simrel documentation built on Nov. 19, 2022, 7:05 a.m.