# tests/testthat/test_05_simplex_calculation.R

# Group the tests below under a common label in the testthat reporter.
# NOTE(review): context() is superseded in testthat's 3rd edition, where the
# file name is used instead; keep it only while the package uses edition 2.
testthat::context("Check calculations")

testthat::test_that("Simplex identifies nearest neighbors correctly", {
    # Observed series (99 points) used to build a 2-D delay embedding.
    series <- c(-0.056531409251883, 0.059223778257432, 5.24124928046977, -4.85399581474521,
                -0.46134818068973, 0.273317575696793, 0.801806230470337, -0.888891901824982,
                -0.202777622745051, 0.497565422757662, 5.10219324323769, -5.36826459373178,
                -0.17467165498718, 1.06545333399298, 1.97419279178678, -2.91448405223082,
                -0.179969867858605, 0.237962494942611, 1.47828327622468, -1.54267507064286,
                -0.180342027136338, 0.238919610831881, 1.06140368490958, -1.06522901782019,
                -0.214923527940395, 0.452847221462308, 2.13053391555372, -2.55145224744286,
                -0.0307653352327702, 1.1448014288826, -0.0675575239486375, -1.04711881585576,
                -0.00910890042051652, 0.726257323277433, 0.732271192186161, -1.35460378982395,
                -0.0322955446760023, 0.507606440290776, 3.73396587274012, -4.19686615950143,
                -0.0997201857962038, 0.753392632401029, 2.41347231553437, -3.03677401452137,
                -0.141112562089696, 0.446002103079665, 0.223768504955365, -0.615452831633047,
                -0.0216659723974975, 0.292246351104258, 0.20006105300258, -0.469596514211075,
                0.0422676544887819, 0.474264989176278, -0.0416811459395667, -0.53555712696719,
                0.118860281628173, 0.176335117268894, -0.10364820567334, -0.153572235117542,
                0.180339482186409, 0.0566876206447625, -0.140537892644139, 0.0252441742388871,
                0.340689505466622, 0.852833653689839, -1.07051231019616, -0.0937704380137284,
                0.460677118593916, 0.37444382348273, -0.83783628206217, -0.0154896108244113,
                1.34259279914848, -0.495978821807168, -0.472464634960208, -0.415481769949074,
                1.36767605087962, -0.891896943918948, -0.279228283931612, -0.148703043863421,
                2.04524590138255, -1.98431486665908, 0.0602356391036573, -0.0902730939678147,
                0.243344379963862, -0.074421904114315, -0.309150440565139, 0.43675531763949,
                0.178787692802827, 0.0799271040758849, -0.657946157906476, 1.14668210755046,
                -0.791665479471326, 0.482533897248175, -0.798737571552661, 0.439024256063545,
                0.177114631209318, 2.19942374686687, -2.9488856529422)
    n_obs <- length(series)

    # Delay embedding with columns x(t + 1), x(t), x(t - 1); shifting pads the
    # series boundaries with NA.
    lagged <- cbind(c(series[-1], NA), series, c(NA, series[-n_obs]))

    # Library rows: every time except 1 (x(t - 1) is NA) and 64 (the point to
    # be predicted). The leading column holds the time stamp t + 1.
    lib_times <- setdiff(2:99, 64)
    pred_time <- 64L
    lib_block <- cbind(lib_times + 1, lagged[lib_times, ])
    pred_block <- cbind(pred_time + 1, lagged[pred_time, , drop = FALSE])
    block <- rbind(lib_block, pred_block)

    n_lib <- NROW(lib_block)
    pred_row <- n_lib + 1

    # EDM forecast of x(t + 1) from the embedding (x(t), x(t - 1)).
    out <- block_lnlp(block,
                      lib = c(1, n_lib),
                      pred = c(pred_row, pred_row),
                      first_column_time = TRUE,
                      tp = 0,
                      columns = c(2, 3), target_column = 1,
                      stats_only = FALSE, silent = TRUE)
    model_est <- out$model_output[[1]]$pred

    # Reference forecast: Euclidean distances from the prediction point to
    # every row of the embedding, with the prediction point itself masked.
    pred_dist <- as.matrix(dist(block[, 3:4]))[pred_row, ]
    pred_dist[pred_row] <- NA

    # Simplex with E = 2 uses the E + 1 = 3 closest neighbours, weighted by
    # exp(-d / d_min).
    num_nn <- 3
    nearest <- head(order(pred_dist), num_nn)
    nn_dist <- pred_dist[nearest]
    w <- exp(-nn_dist / nn_dist[1])
    est <- sum(w * block[nearest, 2]) / sum(w)

    testthat::expect_equal(model_est, est)
})

testthat::test_that("Simplex excludes ties in nearest neighbors correctly", {
    # Row 1 (x = 0, target 0) is predicted from library rows 2..N. The two
    # closest library points (x = 2 and x = -2, both at distance 2) have
    # targets 1 and -1, so with num_neighbors = 2 their equal-weight average
    # is 0, matching the observation. The points tied at distance 5 must
    # not be mixed in.
    expect_ties_excluded <- function(block) {
        out <- block_lnlp(block, lib = c(2, NROW(block)), pred = c(1, 1),
                          tp = 0, columns = 2, target_column = 1,
                          num_neighbors = 2, stats_only = FALSE)
        # expect_equal(object, expected): the forecast is the value under
        # test and the observation is the reference.
        testthat::expect_equal(out$model_output[[1]]$pred,
                               out$model_output[[1]]$obs)
    }

    ## smallish block (does full search over neighbors)
    expect_ties_excluded(
        data.frame(target = c(0, 1, 9, 9, -3, -3, -1),
                   x =      c(0, 2, 5, 5, -5, -5, -2))
    )

    ## larger block (does incremental search for neighbors)
    expect_ties_excluded(
        data.frame(target = c(0, 1, rep(c(9, -3), each = 100), -1),
                   x =      c(0, 2, rep(c(5, -5), each = 100), -2))
    )
})

testthat::test_that("Simplex includes ties in nearest neighbors correctly", {
    # Reference simplex forecast for row 1 (x = 0) using library rows 2..N,
    # with distance |x| in the single embedding column.
    # Assumes the fixture layout used below: the closest distance is shared
    # by fewer than num_neighbors points, and every other library point is
    # tied at one larger distance. Those tied points split the remaining
    # (num_neighbors - n_closest) neighbour slots equally.
    manual_forecast <- function(block, num_neighbors) {
        nn <- seq(from = 2, to = NROW(block))
        neighbor_dist <- abs(block[nn, 2])
        weights <- exp(-neighbor_dist / min(neighbor_dist))
        tied <- neighbor_dist > min(neighbor_dist)
        n_slots <- num_neighbors - sum(!tied)
        weights[tied] <- n_slots * weights[tied] / sum(tied)
        sum(weights * block[nn, 1]) / sum(weights)
    }

    # Simplex forecast of row 1 from block_lnlp with the given neighbour count.
    model_forecast <- function(block, num_neighbors) {
        out <- block_lnlp(block, lib = c(2, NROW(block)), pred = c(1, 1),
                          tp = 0, columns = 2, target_column = 1,
                          num_neighbors = num_neighbors, stats_only = FALSE)
        out$model_output[[1]]$pred
    }

    ## smallish block (does full search over neighbors):
    ## 2 points at distance 2, 4 tied at distance 5 sharing 1 slot
    block <- data.frame(target = c(0, 1, 9, 9, -3, -3, -1),
                        x =      c(0, 2, 5, 5, -5, -5, -2))
    testthat::expect_equal(model_forecast(block, 3),
                           manual_forecast(block, 3))

    ## larger block (does incremental search for neighbors):
    ## 2 points at distance 2, 200 tied at distance 5 sharing 2 slots
    block <- data.frame(target = c(0, 1, rep(c(9, -4), each = 100), -1),
                        x =      c(0, 2, rep(c(5, -5), each = 100), -2))
    testthat::expect_equal(model_forecast(block, 4),
                           manual_forecast(block, 4))
})
# ha0ye/rEDM documentation built on March 30, 2021, 11:21 p.m.