# tests/testthat/test-compute-gini.R

# Reference (pure R) Gini impurity of a set of class labels.
#
# Computes sum_k p_k * (1 - p_k), where p_k is the proportion of `vals`
# falling in class k as tabulated by `table()` (NA values are dropped by
# `table()`'s default). When `vals` is a factor, every declared level is
# tabulated, so unobserved levels contribute a zero count (and zero impurity).
#
# @param vals A vector or factor of class labels.
# @return A single numeric impurity value. An empty node (no tabulated
#   observations) returns 0 instead of NaN, so a weighted combination of
#   split sides stays finite when one side is empty.
gini_impurity <- function(vals) {
  counts <- table(vals)
  total <- sum(counts)

  # Guard against 0 / 0: a factor with levels but no observations yields
  # all-zero counts, which would otherwise produce NaN.
  if (total == 0) {
    return(0)
  }

  props <- counts / total
  sum(props * (1 - props))
}

# Checks the package's compiled Gini computation (`compute_gini_exported`)
# against a pure-R reference built from `gini_impurity()` on a random split.
# NOTE: no seed is set here, so the random inputs differ between runs; the
# expectation is meant to hold for any draw.
test_that(
  desc = "gini index matches expected answer",
  code = {

    n <- 100

    # Single-column 0/1 outcome matrix, unit case weights, and a random 0/1
    # split assignment (about one third of observations land in g == 1).
    y <- matrix(rbinom(n, size = 1, prob = 1/2), ncol = 1)
    w <- rep(1, n)
    g <- rbinom(n, size = 1, prob = 1/3)

    # Relabel outcomes on each side of the split. Fixing both levels keeps a
    # class that is absent on one side in the tabulation (as a zero count).
    vals_1 <- factor(
      y[g == 1],
      levels = c(0, 1),
      labels = c("blue", "green")
    )
    vals_0 <- factor(
      y[g == 0],
      levels = c(0, 1),
      labels = c("blue", "green")
    )

    gini_1 <- gini_impurity(vals = vals_1)
    gini_0 <- gini_impurity(vals = vals_0)

    # Expected value: each side's impurity weighted by that side's share of
    # the total case weight.
    target <- gini_1 * sum(w[g == 1]) / sum(w) +
      gini_0 * sum(w[g == 0]) / sum(w)

    cpp <- compute_gini_exported(y, w, g)

    # testthat's contract is expect_equal(object, expected): actual first.
    expect_equal(cpp, target)

  }
)

# Try the aorsf package in your browser
#
# Any scripts or data that you put into this service are public.
#
# aorsf documentation built on June 22, 2024, 10:31 a.m.