# R/gen_pivot_data.R

#' Generate pivot data set
#'
#' Generates a data set with pivot and non-pivot features for several domains.
#' Pivot features are features that have the same distribution across domains.
#' Non-pivot features preserve the class relationships but distribution means
#' have been shifted across domains (use the `plot` method to observe this).
#'
#' This function outputs a balanced data set (same number of observations for
#' each class and domain combination).
#'
#' One mean shift is drawn per domain and that same shift is applied to every
#' non-pivot feature. Class means are drawn independently for each feature.
#'
#' @param n_nonpivots Number of non-pivot features. Values less than one
#'   produce no non-pivot features.
#' @param n_pivots Number of pivot features. Values less than one produce no
#'   pivot features.
#' @param n_domains Number of domains. Must be at least 1.
#' @param n_classes Number of possible classes. Must be between 1 and 26
#'   because classes are labelled with `LETTERS`.
#' @param n Number of observations. This is adjusted to the nearest multiple of
#'   `n_domains * n_classes` to allow for a balanced data set. A message
#'   reports the adjusted value.
#' @param sd_class_means Standard deviation of class means. Smaller values will
#'   result in features with overlapping distributions.
#' @param sd_np_means Standard deviation of the non-pivot feature means. This
#'   controls the distribution shift across domains for non-pivot features.
#' @param sd_obs Standard deviation of the observations.
#'
#' @return `gen_pivot_data` returns an object of type "pivot_data" and
#'   "data.frame" with columns `Domain`, `Class`, the pivot features
#'   (`P_Feature_*`) and the non-pivot features (`NP_Feature_*`).
#'
#'   The function `plot` produces a plot of domain densities facetted by pivot
#'   and non-pivot features.
#'
#' @examples
#' pivot_data <- gen_pivot_data(1, 1, 2, 2, 200)
#' plot(pivot_data)
#' library(ggplot2)
#' ggplot(pivot_data, aes(x = NP_Feature_1, y = P_Feature_1, colour = Class)) +
#'   geom_point() +
#'   facet_wrap(~Domain)
#'
#' @author Cameron Roach
#'
#' @importFrom magrittr "%>%"
#' @export
gen_pivot_data <- function(n_nonpivots, n_pivots, n_domains, n_classes, n,
                           sd_class_means = 1, sd_np_means = 1, sd_obs = 1) {
  # Class labels come from LETTERS; indexing past its end would give NA and
  # collapse several classes into a single "Class_NA" label.
  if (n_classes < 1 || n_classes > length(LETTERS)) {
    stop("`n_classes` must be between 1 and ", length(LETTERS),
         " (classes are labelled with LETTERS).", call. = FALSE)
  }
  if (n_domains < 1) {
    stop("`n_domains` must be at least 1.", call. = FALSE)
  }

  domains <- paste0("Domain_", seq_len(n_domains))
  classes <- paste0("Class_", LETTERS[seq_len(n_classes)])

  # Round the requested size to a whole number of observations per
  # (domain, class) cell so that the data set is balanced.
  n_input <- n
  n_per_cell <- round(n / (n_domains * n_classes))
  n_per_class <- n_per_cell * n_domains
  n_per_domain <- n_per_cell * n_classes
  n <- n_per_cell * n_domains * n_classes

  message("Data being generated for ", n, " observations to ",
          "create balanced data set.\nUser input specified ", n_input,
          " observations.")
  if (n_per_cell < 1) {
    warning("`n` is too small for a balanced data set; ",
            "no observations were generated.", call. = FALSE)
  }

  # Domains form contiguous row blocks while classes cycle row by row, so
  # every (domain, class) pair receives exactly `n_per_cell` rows.
  pivot_data <- data.frame(
    Domain = rep(domains, each = n_per_domain),
    Class = rep(classes, times = n_per_class)
  )

  # Simulate pivot features. For each class and feature, simulate observations
  # from a randomly selected distribution shared by all domains.
  if (n_pivots > 0) {
    pivots <- paste0("P_Feature_", seq_len(n_pivots))
    for (iP in pivots) {
      for (iC in classes) {
        class_mean <- stats::rnorm(1, 0, sd_class_means)
        idx <- pivot_data$Class == iC
        pivot_data[idx, iP] <- stats::rnorm(n_per_class, class_mean, sd_obs)
      }
    }
  }

  # Simulate non-pivot features. Each class and feature gets its own random
  # mean, which is then offset by a per-domain shift before sampling.
  if (n_nonpivots > 0) {
    nonpivots <- paste0("NP_Feature_", seq_len(n_nonpivots))
    # One shift per domain, drawn once and reused for every non-pivot feature.
    domain_shifts <- stats::setNames(
      stats::rnorm(n_domains, 0, sd_np_means),
      domains
    )
    for (iNP in nonpivots) {
      for (iC in classes) {
        class_mean <- stats::rnorm(1, 0, sd_class_means)
        for (iD in domains) {
          idx <- pivot_data$Class == iC & pivot_data$Domain == iD
          pivot_data[idx, iNP] <- stats::rnorm(
            n_per_cell,
            class_mean + domain_shifts[[iD]],
            sd_obs
          )
        }
      }
    }
  }

  class(pivot_data) <- c("pivot_data", "data.frame")
  pivot_data
}

#' Density plots for a pivot_data object
#'
#' Reshapes every column whose name contains "Feature" into long format and
#' draws one density curve per class, with domains as facet rows and features
#' as facet columns.
#'
#' @param x `pivot_data` object.
#' @param ... Not used by this method.
#'
#' @return A "ggplot" object showing density plots for the `pivot_data` object.
#' @export
#'
#' @author Cameron Roach
plot.pivot_data <- function(x, ...) {
  # Stack all feature columns into Feature/Value pairs for facetting.
  long_data <- tidyr::gather(x, Feature, Value, dplyr::contains("Feature"))

  density_layers <- ggplot2::ggplot(
    long_data,
    ggplot2::aes(x = Value, fill = Class)
  ) +
    ggplot2::geom_density(alpha = 0.3) +
    ggplot2::facet_grid(Domain ~ Feature)

  density_layers +
    ggplot2::labs(
      title = "Density plots for simulated data",
      subtitle = "Data simulated using gen_pivot_data",
      y = "Density"
    )
}

#' Test for a pivot_data object
#'
#' Reports whether an object inherits from class "pivot_data".
#'
#' @param x Any `R` object.
#'
#' @return `is.pivot_data` returns `TRUE` if `x` inherits from class
#'   "pivot_data" and `FALSE` otherwise.
#' @export
#'
#' @author Cameron Roach
is.pivot_data <- function(x) {
  has_pivot_class <- inherits(x, what = "pivot_data", which = FALSE)
  has_pivot_class
}
# camroach87/semisupervisr documentation built on May 13, 2019, 11:04 a.m.