tests/testthat/test-mlclassificationknn.R

context("Machine Learning KNN Classification")

# Test fixed model #########################################################
options <- initMlOptions("mlClassificationKnn")
options$holdoutData <- "holdoutManual"
options$modelOptimization <- "manual"
options$modelValid <- "validationManual"
options$predictionsColumn <- ""
options$predictors <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")
options$predictors.types <- rep("scale", 4)
options$savePath <- ""
options$setSeed <- TRUE
options$target <- "Species"
options$target.types <- "nominal"
options$testIndicatorColumn <- ""
options$testSetIndicatorVariable <- ""
options$dataSplitPlot <- FALSE
options$confusionTable <- FALSE
set.seed(1)
results <- jaspTools::runAnalysis("mlClassificationKnn", "iris.csv", options)

table <- results[["results"]][["classificationTable"]][["data"]]
jaspTools::expect_equal_tables(table,
		list("Euclidean", 30, 120, 3, 0.9, "rectangular"))

# Test optimized model #########################################################
options <- initMlOptions("mlClassificationKnn")
options$addPredictions <- FALSE
options$addIndicator <- FALSE
options$andrewsCurve <- TRUE
options$predictionsColumn <- ""
options$classProportionsTable <- TRUE
options$savePath <- ""
options$saveModel <- FALSE
options$holdoutData <- "holdoutManual"
options$modelOptimization <- "optimized"
options$modelValid <- "validationManual"
options$noOfFolds <- 5
options$errorVsKPlot <- TRUE
options$weightsPlot <- TRUE
options$predictors <- c("Alcohol", "Malic", "Ash", "Alcalinity", "Magnesium", "Phenols",
                           "Flavanoids", "Nonflavanoids", "Proanthocyanins", "Color",
                           "Hue", "Dilution", "Proline")
options$predictors.types <- rep("scale", length(options$predictors))
options$rocCurve <- TRUE
options$setSeed <- TRUE
options$target <- "Type"
options$target.types <- "nominal"
options$testDataManual <- 0.2
options$testIndicatorColumn <- ""
options$testSetIndicatorVariable <- ""
options$validationDataManual <- 0.2
options$validationMeasures <- TRUE
options$tableShap <- TRUE
options$fromIndex <- 1
options$toIndex <- 5
options$featureImportanceTable <- TRUE
set.seed(1)
results <- jaspTools::runAnalysis("mlClassificationKnn", "wine.csv", options)


test_that("Andrews Curves Plot matches", {
  plotName <- results[["results"]][["andrewsCurve"]][["data"]]
  testPlot <- results[["state"]][["figures"]][[plotName]][["obj"]]
  jaspTools::expect_equal_plots(testPlot, "andrews-curves-plot")
})

test_that("Class Proportions table results match", {
  table <- results[["results"]][["classProportionsTable"]][["data"]]
  jaspTools::expect_equal_tables(table,
                      list(0.331460674157303, 1, 0.371428571428571, 0.324561403508772, 0.310344827586207,
                           0.398876404494382, 2, 0.4, 0.368421052631579, 0.517241379310345,
                           0.269662921348315, 3, 0.228571428571429, 0.307017543859649,
                           0.172413793103448))
})

test_that("K-Nearest Neighbors Classification table results match", {
  table <- results[["results"]][["classificationTable"]][["data"]]
  jaspTools::expect_equal_tables(table,
                      list("Euclidean", 1, 35, 114, 29, 0.942857142857143, 0.896551724137931,
                           "rectangular"))
})

test_that("Confusion Matrix table results match", {
  table <- results[["results"]][["confusionTable"]][["data"]]
  jaspTools::expect_equal_tables(table,
                      list("Observed", 1, 13, 0, 0, "", 2, 1, 12, 1, "", 3, 0, 0, 8))
})

test_that("Data Split plot matches", {
  plotName <- results[["results"]][["plotDataSplit"]][["data"]]
  testPlot <- results[["state"]][["figures"]][[plotName]][["obj"]]
  jaspTools::expect_equal_plots(testPlot, "data-split")
})

test_that("Classification Accuracy Plot matches", {
  plotName <- results[["results"]][["errorVsKPlot"]][["data"]]
  testPlot <- results[["state"]][["figures"]][[plotName]][["obj"]]
  jaspTools::expect_equal_plots(testPlot, "classification-accuracy-plot")
})

test_that("Rectangular Weight Function plot matches", {
	plotName <- results[["results"]][["weightsPlot"]][["data"]]
	testPlot <- results[["state"]][["figures"]][[plotName]][["obj"]]
	jaspTools::expect_equal_plots(testPlot, "rectangular-weight-function")
})

test_that("ROC Curves Plot matches", {
  plotName <- results[["results"]][["rocCurve"]][["data"]]
  testPlot <- results[["state"]][["figures"]][[plotName]][["obj"]]
  jaspTools::expect_equal_plots(testPlot, "roc-curves-plot")
})

test_that("Evaluation Metrics table results match", {
  table <- results[["results"]][["validationMeasures"]][["data"]]
  jaspTools::expect_equal_tables(table,
   list(0.971428571428571, 0.977272727272727, 0.962962962962963, 0.0714285714285714,
        0, 0, 0.0454545454545455, 1, 0.941468871691272, 1, 0.928571428571429,
        1, 0.4, 13, 0.954545454545455, 6.5, 0.942857142857143, 0.928571428571429,
        0.923076923076923, 0, 0.142857142857143, 0.0869565217391304,
        0, 2, 0.884651736929383, 0.91304347826087, 1, 0.857142857142857,
        0.342857142857143, 14, 1, 6, 0.971428571428571, 0.981481481481482,
        0.941176470588235, 0.111111111111111, 0, 0, 0.037037037037037,
        3, 0.925184888651615, 1, 0.888888888888889, 1, 0.257142857142857,
        8, 0.962962962962963, 4, 0.961904761904762, 0.962441879108546,
        0.942028777322895, 0.0608465608465608, 0.0476190476190476, 0.0289855072463768,
        0.0274971941638608, "Average / Total", 0.91710183242409, 0.971014492753623,
        0.948072562358277, 0.942857142857143, 1, 35, 0.972502805836139,
        5.5
        ))
})

test_that("Additive Explanations for Predictions of Test Set Cases table results match", {
	table <- results[["results"]][["tableShap"]][["data"]]
	jaspTools::expect_equal_tables(table,
		list(-0.0175438596491228, 0.0614035087719298, 0.0350877192982456, 0.0789473684210527,
			 0.0526315789473685, 0.087719298245614, 0.219298245614035, 0.105263157894737,
			 -0.00877192982456138, 0.0175438596491229, 0.043859649122807,
			 0, 0, 0.324561403508772, 1, "1 (1)", 0, 0.18421052631579, 0.0350877192982456,
			 0.0175438596491229, 0, 0.149122807017544, 0, 0.0175438596491229,
			 0, 0.0964912280701754, 0.0614035087719298, -0.0175438596491229,
			 0.131578947368421, 0.324561403508772, 2, "1 (1)", 0.0526315789473684,
			 0.228070175438597, 0.0175438596491229, 0, 0, 0.0350877192982456,
			 0.140350877192982, 0, 0.0964912280701754, 0, 0.043859649122807,
			 0, 0.0614035087719298, 0.324561403508772, 3, "1 (1)", 0.0175438596491229,
			 0.0964912280701754, 0.0263157894736842, 0, 0, 0, 0.210526315789474,
			 0, 0.0701754385964912, 0, 0.00877192982456143, 0, 0.245614035087719,
			 0.324561403508772, 4, "1 (1)", -0.0175438596491229, 0.192982456140351,
			 0.0263157894736842, 0, 0, 0.175438596491228, 0, 0.0350877192982456,
			 0.00877192982456143, 0.0175438596491228, 0.0964912280701755,
			 0.00877192982456143, 0.131578947368421, 0.324561403508772, 5,
			 "1 (1)"))
})

# Tests specific for windows and linux
test_that("Feature Importance Metrics table results match", {
	testthat::skip_on_os("mac")
	table <- results[["results"]][["featureImportanceTable"]][["data"]]
	jaspTools::expect_equal_tables(table,
		list(35.3677070283885, "Proline", 12.7102697133271, "Alcohol", 11.7892356761295,
			 "Ash", 5.71041103062523, "Hue", 5.52620422318571, "Color", 4.42096337854857,
			 "Alcalinity", 4.23675657110904, "Malic", 2.76310211159285, "Dilution",
			 0.921034037197618, "Phenols", 0.921034037197618, "Nonflavanoids",
			 0.368413614879047, "Magnesium", 0.368413614879047, "Proanthocyanins",
			 0, "Flavanoids"))
})

# Tests specific for mac
test_that("Feature Importance Metrics table results match", {
	testthat::skip_on_os(c("windows", "linux"))
	table <- results[["results"]][["featureImportanceTable"]][["data"]]
	jaspTools::expect_equal_tables(table,
		list(35.3677070283885, "Proline", 12.7102697133271, "Alcohol", 11.7892356761295,
			 "Ash", 5.71041103062523, "Hue", 5.52620422318571, "Color", 4.42096337854857,
			 "Alcalinity", 4.23675657110904, "Malic", 2.76310211159285, "Dilution",
			 0.921034037197618, "Nonflavanoids", 0.921034037197618, "Phenols",
			 0.368413614879047, "Proanthocyanins", 0.368413614879047, "Magnesium",
			 0, "Flavanoids"))
})
jasp-stats/jaspMachineLearning documentation built on April 5, 2025, 3:52 p.m.