# Label for the group of tests in this file.
context("test-survival.R -- Pooled hazards model")
# Turn the `sl3.verbose` option off for this session
# (presumably silences learner progress output -- confirm against sl3 docs).
options(sl3.verbose = FALSE)
library(data.table)
library(origami)
#' True treatment mechanism used by the simulation.
#'
#' Applies the logistic function to a scaled copy of each of the first three
#' columns of `W`, binds the three score vectors into a matrix, and passes
#' that matrix to `normalize_rows()` so that (per the note below) each row
#' sums to 1.
#'
#' @param W Matrix-like object indexed as `W[, j]`; needs at least three
#'   columns. Any further columns are ignored.
#' @return The value of `normalize_rows()` applied to a matrix with one row
#'   per row of `W` and columns `A1`, `A2`, `A3`.
g0 <- function(W) {
  # A constant-probability alternative, rep(0.5, nrow(W)), was left disabled.
  scale_factor <- 0.8
  # Only columns 1-3 are read; the previous unused `W[, 4]` lookup made the
  # function fail on inputs that had exactly three columns.
  A1 <- plogis(scale_factor * W[, 1])
  A2 <- plogis(scale_factor * W[, 2])
  A3 <- plogis(scale_factor * W[, 3])
  A <- cbind(A1, A2, A3)
  # make sure A sums to 1
  # Returned as a plain value (not a trailing assignment) so the result is
  # visible when the function is called interactively.
  normalize_rows(A)
}
#' Simulate covariates and a categorical treatment drawn from `g0()`.
#'
#' Draws an `n` x `p` matrix of standard-normal covariates named `W1..Wp`,
#' evaluates `g0()` on it, and samples one category per row from the
#' resulting probabilities.
#'
#' @param n Number of observations (rows).
#' @param p Number of covariates. `g0()` indexes the leading columns of `W`,
#'   so `p` must be large enough for it.
#' @return A list with `data`, a data.table holding the covariates and the
#'   factor `A`, and `truth`, the matrix returned by `g0()`.
gen_data <- function(n = 1000, p = 4) {
  W <- matrix(rnorm(n * p), nrow = n)
  colnames(W) <- paste0("W", seq_len(p))
  # The visible body of g0() draws no random numbers, so a single evaluation
  # serves both the sampling step and the returned truth (assuming
  # normalize_rows() is deterministic as well).
  g0W <- g0(W)
  # One multinomial draw per row; which() recovers the sampled category index.
  draws <- apply(g0W, 1, function(pAi) which(rmultinom(1, 1, pAi) == 1))
  # Fix the levels to every category so a category that happens never to be
  # sampled cannot shift the integer codes of the remaining ones.
  A <- factor(draws, levels = seq_len(ncol(g0W)))
  df <- data.table(W, A)
  list(data = df, truth = g0W)
}
# Simulate a data set with a fixed seed, fit three learners to the categorical
# outcome A, and score each fit against the true probabilities from g0().

# Seed the RNG before simulating so the generated data are repeatable.
set.seed(1234)
sim_results <- gen_data(1000)
dat <- sim_results$data
# Replace the factor A by its integer level codes, stored as numeric.
dat <- dat[, A := as.numeric(A)]
# True probabilities returned by gen_data(); not referenced again in the
# visible code (p0 below is recomputed from the task's covariates).
g0W <- sim_results$truth
# Covariates are every column whose name starts with "W".
Wnodes <- grep("^W", names(dat), value = TRUE)
Anode <- "A"
task <- sl3_Task$new(dat, covariates = Wnodes, outcome = Anode)
# NOTE(review): hazards_task is built but never used below; presumably kept
# to check that pooled_hazard_task() runs on this task -- confirm.
hazards_task <- pooled_hazard_task(task)
# Pooled hazards learner that uses an xgboost learner as its binomial learner.
lrnr_ph <- make_learner(Lrnr_pooled_hazards,
binomial_learner = make_learner(Lrnr_xgboost)
)
fit <- lrnr_ph$train(task)
# NOTE(review): this fit is queried with base_predict() and no task, while
# the other two fits use predict(task) -- confirm this is intended.
preds <- unpack_predictions(fit$base_predict())
# True probabilities evaluated on the covariates stored in the task.
p0 <- g0(as.matrix(task$X))
# Row-wise sum of p0 * log(prediction), averaged over rows.
# NOTE(review): the value is computed and discarded; no expectation is
# asserted on it here or for the two learners below.
mean(rowSums(p0 * log(preds)))
# Same metric for an xgboost learner trained directly on the task.
lrnr_xgb <- make_learner(Lrnr_xgboost)
xgb_fit <- lrnr_xgb$train(task)
xgb_preds <- unpack_predictions(xgb_fit$predict(task))
mean(rowSums(p0 * log(xgb_preds)))
# Same metric for Lrnr_mean (presumably a covariate-free baseline -- confirm).
lrnr_mean <- make_learner(Lrnr_mean)
mean_fit <- lrnr_mean$train(task)
mean_preds <- unpack_predictions(mean_fit$predict(task))
mean(rowSums(p0 * log(mean_preds)))
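# NOTE: the two lines below are leftover web-page boilerplate, not R code;
# they are kept as comments so the file still parses.
# Add the following code to your website.
# For more information on customizing the embed code, read Embedding Snippets.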