library(rstan)
library(tidyverse)
library(cowplot)
# Session setup ----
theme_set(theme_cowplot())

# Run Stan chains in parallel. detectCores() is documented to return NA when
# the number of cores cannot be determined; fall back to a single core in
# that case instead of storing NA in the option.
options(mc.cores = max(1L, parallel::detectCores(), na.rm = TRUE))
# Let rstan write compiled models to disk so they are not recompiled needlessly.
rstan_options(auto_write = TRUE)

# Load the development version of the package in the working directory.
# NOTE(review): presumably this is what provides hmm_simulator(), brmhmm(),
# sbc() and the other helpers used below - they are not defined in this file.
devtools::load_all()
# Simulate a small data set with noisy observations and eyeball it ----
# NOTE(review): the positional arguments are presumably N_series, N_time and
# N_mid_states (the names used in a later call) - confirm against
# hmm_simulator().
data <- hmm_simulator(10, 20, 4, use_noisy_states = TRUE)

# Columns shared by all three data frames below.
serie_data <- data$observed$serie_data
# seq_len() rather than 1:n: an empty state table then yields no factor levels
# instead of the bogus levels c(1, 0).
state_levels <- seq_len(nrow(data$observed$observed_state_data))

# Rows flagged as "improving" in the simulated ground truth; drawn as a thin
# gray overlay on the first plot.
improving_data <- data.frame(
  day = serie_data$.time,
  patient = serie_data$.serie,
  improving = data$true$true_improving
) %>%
  filter(improving)

# True (hidden) base states per patient and day.
data.frame(
  day = serie_data$.time,
  patient = serie_data$.serie,
  breathing = data$true$true_base_states
) %>%
  mutate(breathing = factor(breathing, state_levels, ordered = TRUE)) %>%
  ggplot(aes(x = day, y = patient, fill = breathing)) +
  geom_tile(width = 1, height = 0.5) +
  geom_tile(data = improving_data, width = 1, height = 0.1, fill = "gray") +
  ggtitle("True - gray lines for improving")


# Observed states per patient and day, same layout as the plot above.
data.frame(
  day = serie_data$.time,
  patient = serie_data$.serie,
  breathing = serie_data$.observed
) %>%
  mutate(breathing = factor(breathing, state_levels, ordered = TRUE)) %>%
  ggplot(aes(x = day, y = patient, fill = breathing)) +
  geom_tile(width = 1, height = 0.5) +
  ggtitle("Observed")
# Fit the generated Stan model to a single noise-free simulated data set ----
# Shared sampler setting; also reused by the SBC run further down.
adapt_delta <- 0.9

data <- hmm_simulator(N_series = 10, N_time = 12, N_mid_states = 4, use_noisy_states = FALSE)
standata <- make_standata_hmm(data$observed)


#Temporary hack
# Registers the package's stan_log_lik method for class "rate_hmm" inside the
# brms namespace so that brms' internal S3 dispatch can find it.
registerS3method("stan_log_lik", class = "rate_hmm", method = stan_log_lik.rate_hmm, envir = asNamespace("brms"))

single_code <- make_stancode_hmm(data$observed)

model_file <- here::here("local_temp_data", "devel_single_model.stan")
# Make sure the scratch directory exists; on a fresh checkout the write below
# would otherwise fail.
dir.create(dirname(model_file), showWarnings = FALSE, recursive = TRUE)
# Target file passed positionally: the `path` argument name is deprecated in
# readr >= 1.4.0 (renamed to `file`), positional use works with both.
write_lines(single_code, model_file)
model_single <- cmdstanr::cmdstan_model(model_file)




# Supports both backends: an rstan "stanmodel" is sampled directly, a cmdstanr
# model is sampled and its CSV output read back as an rstan fit so that the
# downstream evaluation code receives the same kind of object either way.
if (inherits(model_single, "stanmodel")) {
  fit <- sampling(model_single, data = standata, control = list(adapt_delta = adapt_delta))
} else {
  cmdstan_fit <- model_single$sample(data = standata, adapt_delta = adapt_delta)
  fit <- rstan::read_stan_csv(cmdstan_fit$output_files())
}
# Compare the fit against the simulated ground truth.
evaluation_summary(fit, data$true)
# Fit via the brms-based interface and plot posterior predictions ----
data <- hmm_simulator(20, 15, 2, use_noisy_states = TRUE, N_treatments = 0)


bfit <- brmhmm(data$observed)

summary(bfit)


# Print the underlying Stan fit.
bfit$brmsfit$fit
#xx <- brms::posterior_epred(bfit) %>% hidden_to_corresponding_observed(bfit$data_processed, .)
xx <- brms::posterior_predict(bfit)


xy <- prediction_to_wide_format(bfit$data, xx)


# Plot the series in chunks of `step_size` so each figure stays readable.
step_size <- 6
n_series <- bfit$data_processed$standata$N_series
# seq_len() so that zero series means zero iterations (1:0 would run twice).
for (step in seq_len(ceiling(n_series / step_size))) {
  first_id <- (step - 1) * step_size + 1
  # Clamp the final chunk: when n_series is not a multiple of step_size the
  # unclamped range would include series ids that do not exist.
  last_id <- min(step * step_size, n_series)
  series_id <- first_id:last_id

  # Explicit print(): values are not auto-printed inside a loop.
  posterior_state_plot(xy, series_id, bfit$data) %>% print()

}
# Toggle: simulate noisy observed states in the SBC runs? Also read further
# down to decide which parameter plots to draw.
noisy_states_in_sbc <- TRUE

# Data generator for SBC. Each call draws one fresh simulated data set and
# returns a list with
#   observed     - the Stan data built from the simulated observations,
#   true         - the simulated ground truth,
#   observed_raw - the simulated observations before conversion to Stan data.
# Reads the global `noisy_states_in_sbc`.
sbc_generator <- function() {
  simulated <- hmm_simulator(40, 25, 2, use_noisy_states = noisy_states_in_sbc)

  list(
    observed = make_standata_hmm(simulated$observed),
    true = simulated$true,
    observed_raw = simulated$observed
  )
}

# Run SBC and inspect the results ----
# The Stan code is generated from one throw-away simulated data set.
model_sbc <- stan_model(model_code = make_stancode_hmm(sbc_generator()$observed_raw))

sbc_res <- sbc(model_sbc, generator = sbc_generator, N_steps = 50, control = list(adapt_delta = adapt_delta))

# Keep the results on disk (relative to the current working directory).
saveRDS(sbc_res, "sbc.rds")

# NOTE(review): "b|sd" matches every parameter name containing a "b" or "sd"
# anywhere (e.g. "other_observations_probs"); if only the b_* / sd_* parameters
# are meant, an anchored pattern would be needed - confirm intent.
sbc_res$params %>% filter(grepl("b|sd", param_name)) %>% plot_sbc_params()
sbc_res$params %>% filter(grepl("r_1", param_name)) %>% plot_sbc_params()
sbc_res$params %>% filter(grepl("r_2", param_name)) %>% plot_sbc_params()
if (noisy_states_in_sbc) {
  # Explicit print(): inside a braced block only the value of the last
  # expression is auto-printed at top level, so without it the first plot
  # would be silently discarded.
  # NOTE(review): assumes plot_sbc_params() returns a plot object - confirm.
  sbc_res$params %>% filter(grepl("sensitivity", param_name)) %>% plot_sbc_params() %>% print()
  sbc_res$params %>% filter(grepl("other_observations_probs", param_name)) %>% plot_sbc_params() %>% print()
}

#sbc_res$params %>% plot_sbc_params()
summarise_sbc_diagnostics(sbc_res)
sbc_res$diagnostics


# cas-bioinf/covid19retrospective documentation built on Sept. 7, 2021, 6:19 p.m.