# tests/testthat/test-survregVB.fit.R

library(survival)

# Design matrix: a column of ones followed by the x1 and x2 columns of
# `simulation_nofrailty`, filled column-wise into 300 rows.
X <- matrix(
  c(
    rep(1, 300),
    simulation_nofrailty$x1,
    simulation_nofrailty$x2
  ),
  nrow = 300
)

# Survival outcomes, one per (Time, delta) column pair in the data set.
# NOTE(review): the ".10" / ".30" suffixes presumably denote different
# censoring levels -- confirm against the data documentation.
Y <- Surv(
  simulation_nofrailty$Time,
  simulation_nofrailty$delta
)
Y.10 <- Surv(
  simulation_nofrailty$Time.10,
  simulation_nofrailty$delta.10
)
Y.30 <- Surv(
  simulation_nofrailty$Time.30,
  simulation_nofrailty$delta.30
)

# Prior hyperparameters shared by every test below (described by the
# original author as non-informative priors).
mu_0 <- rep(0, 3)
v_0 <- 0.1
alpha_0 <- 11
omega_0 <- 10

test_that("survregVB.fit", {
  # Each scenario fits the model with the shared design matrix and priors,
  # then compares the whole returned list against stored reference values.

  # Scenario 1: outcome Y (Time / delta).
  sigma_ref <- matrix(
    c(
      0.18803590234, -0.1735466575414, -0.0134736286120,
      -0.17354665754, 0.1726178469305, -0.0000653350561,
      -0.01347362861, -0.0000653350561, 0.0254160913341
    ),
    nrow = 3, ncol = 3, byrow = FALSE,
    dimnames = list(NULL, NULL)
  )
  reference <- list(
    ELBO = -1723.493128,
    alpha = 311,
    omega = 250.8086135,
    mu = c(0.05739213186, 0.5273365895, 0.8418679603),
    Sigma = sigma_ref,
    iterations = 8,
    n = 300
  )
  fit <- survregVB.fit(
    Y, X, alpha_0, omega_0, mu_0, v_0,
    max_iteration = 100, threshold = 0.01
  )
  expect_equal(fit, reference)

  # Scenario 2: outcome Y.10 (Time.10 / delta.10).
  sigma_ref <- matrix(
    c(
      0.18618723857, -0.1720991407945, -0.0136226390896,
      -0.17209914079, 0.1713795066094, 0.0002720367882,
      -0.01362263909, 0.0002720367882, 0.0252914786949
    ),
    nrow = 3, ncol = 3, byrow = FALSE,
    dimnames = list(NULL, NULL)
  )
  reference <- list(
    ELBO = -1509.54707,
    alpha = 279,
    omega = 220.2464816,
    mu = c(-0.02694633027, 0.5870221907, 0.8883491399),
    Sigma = sigma_ref,
    iterations = 8,
    n = 300
  )
  fit <- survregVB.fit(
    Y.10, X, alpha_0, omega_0, mu_0, v_0,
    max_iteration = 100, threshold = 0.01
  )
  expect_equal(fit, reference)

  # Scenario 3: outcome Y.30 (Time.30 / delta.30).
  sigma_ref <- matrix(
    c(
      0.22110740073, -0.205050440078, -0.016367876916,
      -0.20505044008, 0.203811682106, 0.001569021946,
      -0.01636787692, 0.001569021946, 0.029198756794
    ),
    nrow = 3, ncol = 3, byrow = FALSE,
    dimnames = list(NULL, NULL)
  )
  reference <- list(
    ELBO = -1149.432804,
    alpha = 220,
    omega = 181.5454727,
    mu = c(0.1478980612, 0.4472516245, 0.8735885567),
    Sigma = sigma_ref,
    iterations = 8,
    n = 300
  )
  fit <- survregVB.fit(
    Y.30, X, alpha_0, omega_0, mu_0, v_0,
    max_iteration = 100, threshold = 0.01
  )
  expect_equal(fit, reference)
})

test_that("unconverged", {
  # With the iteration cap lowered to 6 the fit is expected to emit a
  # warning whose message matches this pattern.
  not_converged_msg <-
    "The max iteration has been achieved and the algorithm has not converged"

  expect_warning(
    survregVB.fit(
      Y, X, alpha_0, omega_0, mu_0, v_0,
      max_iteration = 6, threshold = 0.01
    ),
    regexp = not_converged_msg
  )
})

# Try the survregVB package in your browser
#
# Any scripts or data that you put into this service are public.
#
# survregVB documentation built on June 8, 2025, 1:46 p.m.