tests/testthat/test-response_weights.R

test_that("can use response weights", {
    # One weight per level of iris$Species (three levels); training a
    # classification forest with them should produce no output, messages or
    # warnings.
    species_weights <- c(0.5, 1, 0.1)
    expect_silent(train(data = iris, response_name = "Species",
                        response_weights = species_weights))
})

test_that("get error when training regression forest", {
    # Sepal.Length is a numeric response, so train() would grow a regression
    # forest; response weights must be rejected in that case.
    expected_msg <-
        "'response_weights' argument only valid for classification forests"
    expect_error(
        train(data = iris, response_name = "Sepal.Length",
              response_weights = c(0.5, 1, 0.1)),
        regexp = expected_msg,
        fixed = TRUE
    )
})

test_that("prediction accuracy for minority class increases with higher weight", {
    # Simulate an imbalanced binary outcome: with intercept -3 the logistic
    # probability of y == 1 is small, so level "1" is the minority class.
    set.seed(3)
    n <- 100
    x <- rnorm(n)
    beta0 <- -3
    beta <- 1
    y <- factor(rbinom(n, 1, plogis(beta0 + beta * x)))
    dat <- data.frame(y = y, x)
    is_minor <- dat$y == "1"

    # Without any minority rows both accuracies below would be NaN and the
    # final comparison would fail with an unhelpful message.
    expect_true(any(is_minor))

    # Fit a forest with the given class weights and return the proportion of
    # minority-class rows that are predicted as the minority class. The seed
    # is reset before each fit so the two forests differ only in the weights.
    minority_accuracy <- function(weights) {
        set.seed(42)
        rf <- train(data = dat, response_name = "y", min_split_n_sample = 50,
                    response_weights = weights)
        pred <- predict(rf, newdata = dat[is_minor, ])
        mean(pred$values == "1", na.rm = TRUE)
    }

    accuracy_minor <- minority_accuracy(c(1, 1))
    accuracy_minor_wtd <- minority_accuracy(c(0.01, 0.99))

    expect_gt(accuracy_minor_wtd, accuracy_minor)
})

test_that("get error when response weights of wrong size", {
    # iris$Species has three levels, but only two weights are supplied here.
    too_few_weights <- c(0.5, 1)
    expect_error(
        train(data = iris, response_name = "Species",
              response_weights = too_few_weights),
        regexp = "Number of response weights not equal to number of classes.",
        fixed = TRUE
    )
})

test_that("get error when response weights are NA", {
    # NOTE(review): the expected text is R's base error for an NA condition
    # inside `if`, not a package-specific message; presumably the weight
    # validation has no dedicated NA check, so this assertion is coupled to
    # how that check is written -- confirm against train()'s argument checks.
    weights_with_na <- c(0.5, 1, NA)
    expect_error(
        train(data = iris, response_name = "Species",
              response_weights = weights_with_na),
        regexp = "missing value where TRUE/FALSE needed",
        fixed = TRUE
    )
})

test_that("get error when response weights are not numeric", {
    # Including a string coerces the whole weight vector to character.
    non_numeric_weights <- c(0.5, 1, "a")
    expected_msg <- paste(
        "Invalid value for 'response_weights'. Please give a vector of",
        "non-negative values."
    )
    expect_error(
        train(data = iris, response_name = "Species",
              response_weights = non_numeric_weights),
        regexp = expected_msg,
        fixed = TRUE
    )
})

Try the literanger package in your browser

Any scripts or data that you put into this service are public.

literanger documentation built on Sept. 30, 2024, 9:15 a.m.