Nothing
# Fit the m-function: a model for the residual utility at stage K+1.
#
# If the "cens_indicator" element of `policy_data` flags no right-censored /
# missing final outcome at stage K+1 (all indicators equal FALSE), no
# m-function is needed and NULL is returned. Otherwise `m_model` is fitted on
# the stage K+1 history of the rows with event == 1, using the residual
# utility V_res = U - U_bar as the response.
#
# Arguments:
#   policy_data  : policy data object (read via get_K(), get_element(),
#                  get_id(), get_history(), get_utility()).
#   m_model      : NULL or a single object inheriting from "q_model"; it is
#                  called as m_model(AH = H, V_res = V_res) to fit the model.
#   full_history : passed on to get_history() for the stage K+1 history.
#
# Returns NULL, or a list of class "m_function" with elements `m_model`
# (the fitted model), `H_names` and `stage` (= K + 1), carrying the
# attribute "full_history".
#
# Errors if `m_model` is not a "q_model", if missing outcomes occur but no
# `m_model` is given, if no outcome is observed, or if fitting fails.
# Warnings raised while fitting are re-signalled with context and the fit
# is kept.
fit_m_function <- function(policy_data,
                           m_model,
                           full_history = FALSE) {
  K <- get_K(policy_data)

  # input checks:
  if (!is.null(m_model) && !inherits(m_model, "q_model")) {
    stop("m_model must be a single q_model.")
  }

  ## check if missing final outcomes occur (at stage K+1).
  ## The indicator has one entry per stage K+1 row, so it is reduced to a
  ## single flag before branching: `if` on a condition of length != 1 is an
  ## error in R >= 4.2. An NA indicator is treated as "possibly missing".
  cens <- get_element(policy_data, "cens_indicator")
  cens_ind <- cens[["indicator"]][which(cens[["stage"]] == (K + 1))]
  if (isTRUE(all(cens_ind == FALSE))) {
    return(NULL)
  }
  if (is.null(m_model)) {
    stop("right-censoring/missing (final) outcome occur at stage K+1, please provide an m_model.")
  }

  ## getting the IDs:
  id <- get_id(policy_data)
  ## getting the history at stage K+1, keeping the rows with event == 1:
  his <- get_history(policy_data,
                     stage = K + 1,
                     full_history = full_history,
                     event_set = c(0, 1, 2))
  H <- get_H(his)[get_event(his) == 1, ]
  id_not_missing <- get_id(his)[get_event(his) == 1][["id"]]
  if (length(id_not_missing) == 0) {
    stop("Unable to fit m_model: all utility outcomes are missing")
  }

  ## observed (complete) utility of the non-missing IDs.
  ## NOTE(review): assumes the rows of get_utility() follow the order of
  ## get_id(policy_data), and that this matches the row order of H - confirm.
  utility <- get_utility(policy_data)
  U <- utility$U[id %in% id_not_missing]
  stopifnot(any(!is.na(U)))
  stopifnot(length(U) == nrow(H))

  ## historic rewards/utility contributions of the same IDs:
  U_his <- get_element(his, "U")
  U_bar <- U_his[["U_bar"]][U_his[["id"]] %in% id_not_missing]
  # guard against silent recycling in the subtraction below
  stopifnot(length(U_bar) == length(U))

  ## residual utility:
  V_res <- U - U_bar

  ## fitting the m-model. A tryCatch() warning handler would unwind and
  ## abandon the fit, so warnings are intercepted with withCallingHandlers():
  ## they are re-signalled with context, then muffled so fitting continues.
  ## Errors are re-raised with context.
  m_fit <- withCallingHandlers(
    tryCatch(
      m_model(AH = H, V_res = V_res),
      error = function(e) {
        stop("Error fitting m_model: ", conditionMessage(e), call. = FALSE)
      }
    ),
    warning = function(w) {
      warning("Warning in m_model fitting: ", conditionMessage(w),
              call. = FALSE)
      invokeRestart("muffleWarning")
    }
  )

  ## setting S3 class and attributes:
  m_function <- list(
    m_model = m_fit,
    H_names = colnames(H),
    stage = K + 1
  )
  class(m_function) <- "m_function"
  attr(m_function, "full_history") <- full_history
  m_function
}
#' @export
predict.m_function <- function(object, new_policy_data, ...) {
  # Predict the stage K+1 value Q = U_bar + (predicted residual utility)
  # for the stage K+1 history built from `new_policy_data`.
  #
  # object          : an "m_function" list with elements `stage` and
  #                   `m_model`, carrying the attribute "full_history".
  # new_policy_data : policy data whose K satisfies K + 1 == object$stage.
  # ...             : unused.
  #
  # Returns the table from get_id_stage() for that history, with a column
  # "Q" added via set(). Stops if the stage counts disagree.
  n_stages <- get_K(new_policy_data)
  if (get_element(object, "stage") != n_stages + 1) {
    stop("The number of stages in m_function and new_policy_data does not match.")
  }

  # history at stage K+1, built the same way as when the model was fitted
  stage_history <- get_history(
    new_policy_data,
    full_history = attr(object, "full_history"),
    stage = n_stages + 1,
    event_set = c(0, 1, 2)
  )
  fitted_model <- getElement(object, "m_model")

  # Output skeleton. The stored H_names are not used here: the complete
  # stage K+1 history is handed to the model.
  out <- get_id_stage(stage_history)
  residual_pred <- predict(fitted_model, new_AH = get_H(stage_history))
  u_bar <- get_element(stage_history, "U")[["U_bar"]]
  set(out, j = "Q", value = u_bar + residual_pred)
  out
}
Any scripts or data that you put into this service are public.
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.