R/simulate_mixture_cube.R

Defines functions simulate_mixture_cube

Documented in simulate_mixture_cube

#' @title Simulate a mixture cube to test \code{CVtreeMLE} against simulated
#' ground-truth.
#' @description Simulate a mixture cube. Three correlated mixture variables
#' are first drawn from a multivariate normal and mapped to the unit interval
#' through the standard normal CDF (uniform marginals when \code{mu} is zero
#' and \code{sigma} has a unit diagonal). Three baseline covariates (age, bmi
#' and sex) are generated, and each observation is assigned to one of the 8
#' subspaces of the cube with multinomial-logit probabilities that depend on
#' those covariates. Within its assigned subspace, each mixture value is
#' rescaled from the unit interval to the bounds of that part of the cube.
#' A three variable cube with one threshold per variable has 8 subspaces. An
#' outcome is then generated as a linear combination of subspace indicators,
#' marginal threshold indicators and the standardized covariates.
#'
#' @details Each observation is assigned to one subspace, labelled
#' \code{p1}-\code{p8} (levels 0-7); "higher"/"lower" are relative to
#' \code{splits}:
#' \itemize{
#'   \item \code{p1} (0): all mixtures lower than their thresholds
#'   \item \code{p2} (1): M1 is higher but M2 and M3 are lower
#'   \item \code{p3} (2): M2 is higher but M1 and M3 are lower
#'   \item \code{p4} (3): M1 and M2 are higher and M3 is lower
#'   \item \code{p5} (4): M3 is higher and M1 and M2 are lower
#'   \item \code{p6} (5): M1 and M3 are higher and M2 is lower
#'   \item \code{p7} (6): M2 and M3 are higher and M1 is lower
#'   \item \code{p8} (7): all mixtures are higher than their thresholds
#' }
#' Subspace membership follows a multinomial logit in the raw
#' (unstandardized) age, bmi and sex, with coefficients drawn at random on
#' every call. The outcome's linear predictor is
#' \code{subspace_assoc_strength_betas[1]}, plus
#' \code{subspace_assoc_strength_betas[k]} for an observation in subspace
#' \code{pk} (k >= 2), plus the marginal terms, plus the standardized age and
#' sex, plus \code{N(0, eps_sd^2)} noise. bmi does not enter the outcome.
#' When \code{binary = TRUE} the outcome is 1 when \code{plogis} of that
#' linear predictor exceeds 0.5 and 0 otherwise (a threshold, not a
#' Bernoulli draw).
#'
#' @param n_obs Number of observations for which to generate data
#' @param splits Length-3 vector indicating where the threshold is placed for
#' each mixture variable
#' @param mins Length-3 vector indicating the minimum values for each mixture
#' variable
#' @param maxs Length-3 vector indicating the maximum values for each mixture
#' variable
#' @param mu Length-3 vector indicating the mean values of the multivariate
#' normal used to generate the mixtures
#' @param sigma 3 x 3 variance-covariance matrix used to generate the mixture
#' variables
#' @param w1_betas Currently unused; retained for backward compatibility.
#' @param w2_betas Currently unused; retained for backward compatibility.
#' @param mix_subspace_betas Currently unused; retained for backward
#' compatibility.
#' @param subspace_assoc_strength_betas Length-8 vector. The first element is
#' the outcome intercept (the level in subspace \code{p1}); element k
#' (k >= 2) is added for observations in subspace \code{pk}.
#' @param marginal_impact_betas Length-3 vector of outcome effects of each
#' mixture exceeding its split, judged on the mixture rescaled to its full
#' \code{[mins, maxs]} range
#' @param eps_sd Standard deviation of the Gaussian error added to Y
#' @param binary TRUE/FALSE depending on if the outcome should be binary
#' @importFrom data.table rbindlist
#' @importFrom dplyr group_by
#' @importFrom stats as.formula glm p.adjust plogis predict
#' @importFrom stats qlogis qnorm qunif rnorm runif
#' @importFrom dplyr mutate
#' @importFrom MASS mvrnorm
#' @importFrom purrr rbernoulli
#' @importFrom rlang :=
#' @importFrom stats sd
#' @return obs: A data frame of the simulated data for the mixture cube, with
#' columns \code{age}, \code{bmi}, \code{sex} (standardized), \code{M1},
#' \code{M2}, \code{M3} (mixtures) and \code{y} (outcome).
#' @export
simulate_mixture_cube <- function(n_obs = 500,
                                  splits = c(0.99, 2.0, 2.5),
                                  mins = c(0, 0, 0),
                                  maxs = c(3, 4, 5),
                                  mu = c(0, 0, 0),
                                  sigma = matrix(
                                    c(
                                      1, 0.5, 0.8, 0.5,
                                      1, 0.7, 0.8, 0.7, 1
                                    ),
                                    nrow = 3, ncol = 3
                                  ),
                                  w1_betas = c(
                                    0.0, 0.01, 0.03, 0.06, 0.1,
                                    0.05, 0.2, 0.04
                                  ),
                                  w2_betas = c(
                                    0.0, 0.04, 0.01, 0.07, 0.15,
                                    0.1, 0.1, 0.04
                                  ),
                                  mix_subspace_betas = c(
                                    0.00, 0.08, 0.05, 0.01,
                                    0.05, 0.033, 0.07,
                                    0.09
                                  ),
                                  subspace_assoc_strength_betas = c(
                                    1, 1, 1, 1,
                                    1, 1, 1, 7
                                  ),
                                  marginal_impact_betas = c(0, 0, 0),
                                  eps_sd = 0.01,
                                  binary = FALSE) {
  stopifnot(
    length(n_obs) == 1L, n_obs >= 1,
    length(splits) == 3L, length(mins) == 3L,
    length(maxs) == 3L, length(mu) == 3L,
    length(subspace_assoc_strength_betas) == 8L,
    length(marginal_impact_betas) == 3L
  )
  n_subspaces <- 8L

  # Spread a length-3 per-mixture vector across all rows (n_obs x 3).
  as_cols <- function(v) matrix(v, nrow = n_obs, ncol = 3L, byrow = TRUE)

  # Center and scale; fall back to centering only when the spread is zero or
  # undefined (single observation, constant column) instead of returning NaN.
  standardize <- function(x) {
    spread <- stats::sd(x)
    centered <- x - mean(x)
    if (is.na(spread) || spread == 0) centered else centered / spread
  }

  # Correlated mixtures mapped to (0, 1) via the standard normal CDF.
  # matrix() keeps the n x 3 shape when n_obs == 1 (mvrnorm drops it).
  rawvars <- matrix(
    MASS::mvrnorm(n = n_obs, mu = mu, Sigma = sigma),
    nrow = n_obs
  )
  pvars <- stats::pnorm(rawvars)

  ## create the covariates
  age <- stats::rnorm(n_obs, 37, 3)
  bmi <- stats::rnorm(n_obs, 20, 1)
  # Bernoulli(0.5) draw; replaces the deprecated purrr::rbernoulli().
  sex <- as.numeric(stats::runif(n_obs) > 0.5)
  covars <- data.frame(age, bmi, sex)

  ## random multinomial-logit coefficients, one per subspace
  b0i <- round(stats::rnorm(n_subspaces, 0.3, 0.01), 2)
  b1i <- round(stats::rnorm(n_subspaces, 0.4, 0.01), 2)
  b2i <- round(stats::rnorm(n_subspaces, 0.5, 0.01), 2)
  b3i <- round(stats::rnorm(n_subspaces, 0.5, 0.01), 2)

  # n_obs x 8 linear predictors on the raw covariates, turned into
  # probabilities with a softmax; subtracting each row's maximum prevents
  # exp() overflow and guarantees rows sum to 1.
  eta <- cbind(1, age, bmi, sex) %*% rbind(b0i, b1i, b2i, b3i)
  eta <- eta - apply(eta, 1, max)
  weights <- exp(eta)
  probs <- weights / rowSums(weights)

  # Inverse-CDF draw of one subspace per row (one uniform per observation).
  # pmin() guards against a cumulative sum that rounds to just below 1.
  cum_probs <- t(apply(probs, 1, cumsum))
  draw <- stats::runif(n_obs)
  rcat_idx <- as.integer(pmin(1 + rowSums(draw > cum_probs), n_subspaces))

  # Row k is subspace p<k> (level k - 1); a 1 means that mixture lies above
  # its split. expand.grid varies M1 fastest, matching the level numbering.
  subspace_grid <- as.matrix(
    expand.grid(M1 = c(0, 1), M2 = c(0, 1), M3 = c(0, 1))
  )
  is_high <- subspace_grid[rcat_idx, , drop = FALSE] == 1

  # Rescale each mixture from (0, 1) into its subspace's bounds.
  lower <- ifelse(is_high, as_cols(splits), as_cols(mins))
  upper <- ifelse(is_high, as_cols(maxs), as_cols(splits))
  ms <- as.data.frame(pvars * (upper - lower) + lower)
  colnames(ms) <- c("M1", "M2", "M3")

  # Marginal indicators: is each mixture, rescaled to its full
  # [mins, maxs] range, above its split?
  full_range <- pvars * (as_cols(maxs) - as_cols(mins)) + as_cols(mins)
  above_split <- (full_range > as_cols(splits)) * 1

  covars[] <- lapply(covars, standardize)

  # Intercept for everyone, plus the subspace-specific shift for p2..p8.
  subspace_effect <- subspace_assoc_strength_betas[1] +
    c(0, subspace_assoc_strength_betas[-1])[rcat_idx]

  lin_pred <- subspace_effect +
    marginal_impact_betas[1] * above_split[, 1] +
    marginal_impact_betas[2] * above_split[, 2] +
    marginal_impact_betas[3] * above_split[, 3] +
    covars$age +
    covars$sex +
    stats::rnorm(n_obs, mean = 0, sd = eps_sd)

  y <- if (isTRUE(as.logical(binary))) {
    # NOTE: deterministic threshold of the logistic mean, not a Bernoulli draw.
    as.numeric(stats::plogis(lin_pred) > 0.50)
  } else {
    lin_pred
  }

  obs <- cbind(covars, ms, y = y)
  obs
}
blind-contours/CVtreeMLE documentation built on June 22, 2024, 8:53 p.m.