#' Generate pivot data set
#'
#' Generates a data set with pivot and non-pivot features for several domains.
#' Pivot features are features that have the same distribution across domains.
#' Non-pivot features preserve the class relationships but distribution means
#' have been shifted across domains (use the `plot` method to observe this).
#'
#' This function outputs a balanced data set (same number of observations for
#' each class within each domain). The number of observations actually
#' generated is reported with `message()`, so it can be silenced with
#' `suppressMessages()`.
#'
#' For every feature a class mean is drawn from a normal distribution with
#' standard deviation `sd_class_means`. For non-pivot features a single shift
#' per domain is drawn (standard deviation `sd_np_means`) and added to the
#' class mean; the same domain shift is shared by all non-pivot features.
#'
#' @param n_nonpivots Number of non-pivot features. Values of zero or less
#'   produce no non-pivot features.
#' @param n_pivots Number of pivot features. Values of zero or less produce no
#'   pivot features.
#' @param n_domains Number of domains. Must be a whole number of at least 1.
#' @param n_classes Number of possible classes. Must be a whole number between
#'   1 and 26 because classes are labelled with `LETTERS`.
#' @param n Number of observations. This is adjusted to the nearest multiple of
#'   `n_domains * n_classes` to allow for a balanced data set.
#' @param sd_class_means Standard deviation of class means. Smaller values will
#'   result in features with overlapping distributions.
#' @param sd_np_means Standard deviation of the non-pivot feature means. This
#'   controls the distribution shift across domains for non-pivot features.
#' @param sd_obs Standard deviation of the observations.
#'
#' @return `gen_pivot_data` returns an object of type "pivot_data" and
#'   "data.frame" with columns `Domain`, `Class`, the pivot features
#'   (`P_Feature_*`) and the non-pivot features (`NP_Feature_*`).
#'
#'   The function `plot` produces a plot of domain densities facetted by pivot
#'   and non-pivot features.
#'
#' @examples
#' pivot_data <- gen_pivot_data(1, 1, 2, 2, 200)
#' plot(pivot_data)
#' library(ggplot2)
#' ggplot(pivot_data, aes(x = NP_Feature_1, y = P_Feature_1, colour = Class)) +
#'   geom_point() +
#'   facet_wrap(~Domain)
#'
#' @author Cameron Roach
#'
#' @importFrom magrittr "%>%"
#' @export
gen_pivot_data <- function(n_nonpivots, n_pivots, n_domains, n_classes, n,
                           sd_class_means = 1, sd_np_means = 1, sd_obs = 1) {
  # Validate the counts that define the Domain x Class grid up front; invalid
  # values otherwise surface later as obscure errors or corrupted labels.
  if (!is.numeric(n_domains) || length(n_domains) != 1 || is.na(n_domains) ||
      n_domains < 1 || n_domains != round(n_domains)) {
    stop("`n_domains` must be a single whole number greater than or equal ",
         "to 1.", call. = FALSE)
  }
  if (!is.numeric(n_classes) || length(n_classes) != 1 || is.na(n_classes) ||
      n_classes < 1 || n_classes != round(n_classes)) {
    stop("`n_classes` must be a single whole number greater than or equal ",
         "to 1.", call. = FALSE)
  }
  # Class labels come from LETTERS; indexing past its end would yield
  # duplicated "Class_NA" labels.
  if (n_classes > length(LETTERS)) {
    stop("`n_classes` must not exceed ", length(LETTERS),
         " because classes are labelled with LETTERS.", call. = FALSE)
  }

  domains <- paste0("Domain_", seq_len(n_domains))
  classes <- paste0("Class_", LETTERS[seq_len(n_classes)])

  # Adjust n to the nearest multiple of the number of Domain x Class cells so
  # that every cell receives the same number of observations.
  n_input <- n
  n <- round(n / (n_domains * n_classes)) * n_domains * n_classes
  message("Data being generated for ", n, " observations to ",
          "create balanced data set.\nUser input specified ", n_input,
          " observations.")

  # Domains are laid out in contiguous blocks while classes cycle within them;
  # because n is a multiple of n_domains * n_classes this gives a balanced
  # design.
  pivot_data <- data.frame(
    Domain = rep(domains, each = n / n_domains),
    Class = rep(classes, times = n / n_classes)
  )

  # Simulate pivot features. For each class and feature, simulate observations
  # from a randomly selected distribution (shared by all domains).
  if (n_pivots > 0) {
    pivots <- paste0("P_Feature_", seq_len(n_pivots))
    for (iP in pivots) {
      for (iC in classes) {
        class_mean <- stats::rnorm(1, 0, sd_class_means)
        idx <- pivot_data$Class == iC
        pivot_data[idx, iP] <- stats::rnorm(n / n_classes, class_mean, sd_obs)
      }
    }
  }

  # Simulate non-pivot features. For each class, feature AND DOMAIN, simulate
  # observations from a distribution whose mean is the class mean plus the
  # domain's shift.
  if (n_nonpivots > 0) {
    nonpivots <- paste0("NP_Feature_", seq_len(n_nonpivots))
    # One shift per domain, drawn once and reused for every non-pivot feature.
    domain_shifts <- stats::rnorm(n_domains, 0, sd_np_means)
    names(domain_shifts) <- domains
    for (iNP in nonpivots) {
      for (iC in classes) {
        class_mean <- stats::rnorm(1, 0, sd_class_means)
        for (iD in domains) {
          idx <- pivot_data$Class == iC & pivot_data$Domain == iD
          pivot_data[idx, iNP] <- stats::rnorm(
            n / n_classes / n_domains,
            class_mean + domain_shifts[[iD]],
            sd_obs
          )
        }
      }
    }
  }

  class(pivot_data) <- c("pivot_data", "data.frame")
  return(pivot_data)
}
#' Density plot method for `pivot_data` objects
#'
#' Reshapes every column whose name contains "Feature" into long format and
#' draws one density curve per class, laid out on a grid with domains as rows
#' and features as columns.
#'
#' @param x `pivot_data` object.
#' @param ... Arguments to be passed to methods. Not used by this method.
#'
#' @return Returns a "ggplot" object showing density plots for the
#'   `pivot_data` object.
#' @export
#'
#' @author Cameron Roach
plot.pivot_data <- function(x, ...) {
  # Stack all feature columns into Feature/Value pairs so they can be facetted.
  long_data <- tidyr::gather(x, Feature, Value, dplyr::contains("Feature"))

  density_layers <- ggplot2::ggplot(
    long_data,
    ggplot2::aes(x = Value, fill = Class)
  ) +
    ggplot2::geom_density(alpha = 0.3)

  plot_labels <- ggplot2::labs(
    title = "Density plots for simulated data",
    subtitle = "Data simulated using gen_pivot_data",
    y = "Density"
  )

  density_layers +
    ggplot2::facet_grid(Domain ~ Feature) +
    plot_labels
}
#' Test for the "pivot_data" class
#'
#' Reports whether an object inherits from class "pivot_data".
#'
#' @param x Any `R` object.
#'
#' @return `is.pivot_data` returns `TRUE` if its argument is an object with
#'   class "pivot_data" and `FALSE` otherwise.
#' @export
#'
#' @author Cameron Roach
is.pivot_data <- function(x) {
  has_pivot_class <- inherits(x, what = "pivot_data")
  return(has_pivot_class)
}
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.