# tests/testthat/test_mixture.R

# Builds a mixture of three zero-mean normal variables (sds 1, 2 and 3) whose
# weights come from uniform(0, 1, 3), then hands the mixture variable to the
# sample_distribution() test helper (defined outside this file).
test_that("continuous mixture variables can be sampled from", {
  # skipped unless check_tf_version() is TRUE
  skip_if_not(check_tf_version())

  mix_weights <- uniform(0, 1, 3)

  # the three components, created in the same order as they are passed
  narrow <- normal(0, 1)
  medium <- normal(0, 2)
  wide <- normal(0, 3)

  mixed <- mixture(narrow, medium, wide, weights = mix_weights)

  sample_distribution(mixed)
})

# Assigns a three-component normal mixture as the distribution of 100 fixed
# values drawn with rnorm(), then hands the mixture weights to the
# sample_distribution() test helper (defined outside this file).
test_that("fixed continuous mixture distributions can be sampled from", {
  # skipped unless check_tf_version() is TRUE
  skip_if_not(check_tf_version())

  mix_weights <- uniform(0, 1, 3)
  observed <- rnorm(100, 0, 2)

  # the three components, created in the same order as they are passed
  narrow <- normal(0, 1)
  medium <- normal(0, 2)
  wide <- normal(0, 3)

  distribution(observed) <- mixture(
    narrow, medium, wide,
    weights = mix_weights
  )

  # the weights, not the observations, are the quantity being sampled
  sample_distribution(mix_weights)
})


# Assigns a three-component bernoulli mixture as the distribution of 100 fixed
# binary values drawn with rbinom(), then hands the mixture weights to the
# sample_distribution() test helper (defined outside this file).
test_that("fixed discrete mixture distributions can be sampled from", {
  # skipped unless check_tf_version() is TRUE
  skip_if_not(check_tf_version())

  mix_weights <- uniform(0, 1, 3)
  binary_obs <- rbinom(100, 1, 0.5)

  # the three components, created in the same order as they are passed
  rare <- bernoulli(0.1)
  even <- bernoulli(0.5)
  common <- bernoulli(0.9)

  distribution(binary_obs) <- mixture(
    rare, even, common,
    weights = mix_weights
  )

  # the weights, not the observations, are the quantity being sampled
  sample_distribution(mix_weights)
})

# Combining a bernoulli() component with a normal() component in one mixture
# must raise an error; the error is checked against the stored snapshot.
test_that("mixtures of fixed and continuous distributions errors", {
  # skipped unless check_tf_version() is TRUE
  skip_if_not(check_tf_version())

  pair_weights <- uniform(0, 1, dim = 2)

  # components are built inside the expectation so any error they raise is
  # captured by it
  expect_snapshot_error(
    mixture(bernoulli(0.5), normal(0, 1), weights = pair_weights)
  )
})


# Combining a multivariate_normal() component with a normal() component of
# dim c(1, 3) in one mixture must raise an error; the error is checked against
# the stored snapshot.
test_that("mixtures of multivariate and univariate errors", {
  # skipped unless check_tf_version() is TRUE
  skip_if_not(check_tf_version())

  pair_weights <- uniform(0, 1, dim = 2)

  # components are built inside the expectation so any error they raise is
  # captured by it
  expect_snapshot_error(
    mixture(
      multivariate_normal(zeros(1, 3), diag(3)),
      normal(0, 1, dim = c(1, 3)),
      weights = pair_weights
    )
  )
})

# Two mixtures whose components do not share a support must each raise an
# error; both errors are checked against stored snapshots.
test_that("mixtures of supports errors", {
  # skipped unless check_tf_version() is TRUE
  skip_if_not(check_tf_version())

  equal_weights <- c(0.5, 0.5)

  # supports differ due to truncation: one normal is truncated to c(0, Inf),
  # the other is not
  expect_snapshot_error(
    mixture(
      normal(0, 1, truncation = c(0, Inf)),
      normal(0, 1),
      weights = equal_weights
    )
  )

  # supports differ due to bounds: lognormal() paired with normal()
  expect_snapshot_error(
    mixture(lognormal(0, 1), normal(0, 1), weights = equal_weights)
  )
})

# Weights created with dim = c(1, 2) for a two-component mixture must raise an
# error; the error is checked against the stored snapshot.
test_that("incorrectly-shaped weights errors", {
  # skipped unless check_tf_version() is TRUE
  skip_if_not(check_tf_version())

  row_weights <- uniform(0, 1, dim = c(1, 2))

  expect_snapshot_error(
    mixture(normal(0, 1), normal(0, 2), weights = row_weights)
  )
})

# A mixture given only one component, or no components at all, must raise an
# error; both errors are checked against stored snapshots.
test_that("mixtures with insufficient distributions errors", {
  # skipped unless check_tf_version() is TRUE
  skip_if_not(check_tf_version())

  lone_weight <- uniform(0, 1)

  # a single component
  expect_snapshot_error(
    mixture(normal(0, 2), weights = lone_weight)
  )

  # no components, only weights
  expect_snapshot_error(
    mixture(weights = lone_weight)
  )
})

# Passes a greta three-component normal mixture and a base-R reference density
# to the compare_distribution() test helper (defined outside this file).
# NOTE(review): presumably the helper checks that the two densities agree at
# `x` - confirm against its definition.
test_that("mixture of normals has correct density", {
  # skipped unless check_tf_version() is TRUE
  skip_if_not(check_tf_version())

  # greta side: mixture of three normals built from the first three entries of
  # `means` and `sds`
  mix_greta <- function(means, sds, weights, dim) {
    first <- normal(means[1], sds[1], dim)
    second <- normal(means[2], sds[2], dim)
    third <- normal(means[3], sds[3], dim)
    mixture(first, second, third, weights = weights)
  }

  # R side: sum over components of weights[k] * dnorm(x, means[k], sds[k]),
  # returned as one value per element of `x`
  mix_r <- function(x, means, sds, weights) {
    n_components <- length(means)
    weighted <- matrix(NA, nrow = length(x), ncol = n_components)

    # column k holds the weighted density of component k
    for (k in seq_len(n_components)) {
      weighted[, k] <- dnorm(x, means[k], sds[k]) * weights[k]
    }

    rowSums(weighted)
  }

  # `x` is left inline so it is only evaluated when the helper uses it
  compare_distribution(
    mix_greta,
    mix_r,
    parameters = list(
      means = c(-2, 2, 5),
      sds = c(3, 0.5, 1),
      weights = c(0.3, 0.6, 0.1)
    ),
    x = rnorm(100, -2, 3)
  )
})

# Passes a greta mixture of three normals, each truncated to c(0, Inf), and a
# reference density built on truncdist::dtrunc() to the compare_distribution()
# test helper (defined outside this file).
# NOTE(review): presumably the helper checks that the two densities agree at
# `x` - confirm against its definition.
test_that("mixture of truncated normals has correct density", {
  # skipped unless check_tf_version() is TRUE
  skip_if_not(check_tf_version())

  lower_truncated <- c(0, Inf)

  # greta side: mixture of three truncated normals built from the first three
  # entries of `means` and `sds`
  mix_greta <- function(means, sds, weights, dim) {
    first <- normal(means[1], sds[1], dim, truncation = lower_truncated)
    second <- normal(means[2], sds[2], dim, truncation = lower_truncated)
    third <- normal(means[3], sds[3], dim, truncation = lower_truncated)
    mixture(first, second, third, weights = weights)
  }

  # R side: sum over components of weights[k] times the density of a normal
  # truncated to [0, Inf), returned as one value per element of `x`
  mix_r <- function(x, means, sds, weights) {
    n_components <- length(means)
    weighted <- matrix(NA, nrow = length(x), ncol = n_components)

    # column k holds the weighted density of component k
    for (k in seq_len(n_components)) {
      component_density <- truncdist::dtrunc(
        x = x,
        spec = "norm",
        a = 0,
        b = Inf,
        mean = means[k],
        sd = sds[k]
      )
      weighted[, k] <- component_density * weights[k]
    }

    rowSums(weighted)
  }

  # abs() keeps every value of `x` non-negative; `x` is left inline so it is
  # only evaluated when the helper uses it
  compare_distribution(
    mix_greta,
    mix_r,
    parameters = list(
      means = c(-2, 2, 5),
      sds = c(3, 0.5, 1),
      weights = c(0.3, 0.6, 0.1)
    ),
    x = abs(rnorm(100, -2, 3))
  )
})

# Passes a greta three-component Poisson mixture and a base-R reference
# density to the compare_distribution() test helper (defined outside this
# file).
# NOTE(review): presumably the helper checks that the two densities agree at
# `x` - confirm against its definition.
test_that("mixture of Poissons has correct density", {
  # skipped unless check_tf_version() is TRUE
  skip_if_not(check_tf_version())

  # greta side: mixture of three Poissons built from the first three entries
  # of `rates`
  mix_greta <- function(rates, weights, dim) {
    first <- poisson(rates[1], dim)
    second <- poisson(rates[2], dim)
    third <- poisson(rates[3], dim)
    mixture(first, second, third, weights = weights)
  }

  # R side: sum over components of weights[k] * dpois(x, rates[k]), returned
  # as one value per element of `x`
  mix_r <- function(x, rates, weights) {
    n_components <- length(rates)
    weighted <- matrix(NA, nrow = length(x), ncol = n_components)

    # column k holds the weighted density of component k
    for (k in seq_len(n_components)) {
      weighted[, k] <- dpois(x, rates[k]) * weights[k]
    }

    rowSums(weighted)
  }

  # `x` is left inline so it is only evaluated when the helper uses it
  compare_distribution(
    mix_greta,
    mix_r,
    parameters = list(
      rates = c(0.1, 2, 5),
      weights = c(0.3, 0.6, 0.1)
    ),
    x = rpois(100, 3)
  )
})

# Same comparison as the fixed-weight normal mixture, but with a separate set
# of three weights for every element of a 10 x 5 x 4 array `x`. The greta
# mixture and the R reference density are passed to the compare_distribution()
# test helper (defined outside this file).
# NOTE(review): presumably the helper checks that the two densities agree at
# `x` - confirm against its definition.
test_that("mixture of normals with varying weights has correct density", {
  # skipped unless check_tf_version() is TRUE
  skip_if_not(check_tf_version())

  # greta side: mixture of three normals built from the first three entries of
  # `means` and `sds`
  mix_greta <- function(means, sds, weights, dim) {
    first <- normal(means[1], sds[1], dim)
    second <- normal(means[2], sds[2], dim)
    third <- normal(means[3], sds[3], dim)
    mixture(first, second, third, weights = weights)
  }

  # R side: for each element of `x`, the weighted sum of the component
  # densities, returned with the same dimensions as `x`
  mix_r <- function(x, means, sds, weights) {
    target_dim <- dim(x)
    n_components <- length(means)

    # one row per component, one column per element of `x`
    by_component <- array(NA, c(n_components, prod(target_dim)))
    for (k in seq_len(n_components)) {
      by_component[k, ] <- dnorm(x, means[k], sds[k])
    }

    # flatten the weights to the same components-by-elements layout
    dim(weights) <- dim(by_component)
    mixed <- colSums(by_component * weights)
    dim(mixed) <- target_dim
    mixed
  }

  out_dim <- c(10, 5, 4)
  n_elements <- prod(out_dim)

  # random weights, one column of three per element, scaled so that each
  # column sums to one
  raw_weights <- matrix(runif(n_elements * 3), nrow = 3, ncol = n_elements)
  column_totals <- rep(colSums(raw_weights), each = nrow(raw_weights))
  varying_weights <- raw_weights / column_totals
  dim(varying_weights) <- c(3, out_dim)

  x_values <- array(rnorm(n_elements, -2, 3), out_dim)

  compare_distribution(
    mix_greta,
    mix_r,
    parameters = list(
      means = c(-2, 2, 5),
      sds = c(3, 0.5, 1),
      weights = varying_weights
    ),
    x = x_values,
    dim = out_dim
  )
})

# Try the greta package in your browser
#
# Any scripts or data that you put into this service are public.
#
# greta documentation built on Sept. 8, 2022, 5:10 p.m.