# ---- Setup ----------------------------------------------------------------

library(rstan)
library(tidyverse)
library(cowplot)

# Plotting and Stan defaults used by the rest of the script.
theme_set(theme_cowplot())
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

# Load the package under development (provides brmshmmdata(), brmhmm(), ...).
devtools::load_all()

# Directory where fitted models are cached; created on first use.
cache_dir <- here::here("local_temp_data")
if (!dir.exists(cache_dir)) {
  dir.create(cache_dir)
}
# ---- Data -------------------------------------------------------------------

# Observed breathing-support states, ordered from "Discharge" to "Death".
# NOTE(review): the abbreviations (AA, LoO2, HiO2, NIPPV, MV, ECMO) are
# presumably levels of respiratory support - confirm against the data source.
disease_levels <- c("AA", "LoO2", "HiO2", "NIPPV", "MV", "ECMO")
breathing_levels <- c("Discharge", disease_levels, "Death")

# Read the wide per-day table, reshape it to one row per patient and day,
# drop days without a recorded state and encode the state as ordered factor.
data_grein <-
  readxl::read_excel(
    here::here("public_data", "Grein_et_al.xlsx"),
    na = c("")
  ) %>%
  select(-Remdesivir_Last_Day) %>%
  pivot_longer(
    starts_with("Day_"),
    names_prefix = "Day_",
    names_to = "day",
    values_to = "breathing"
  ) %>%
  filter(!is.na(breathing)) %>%
  transmute(
    patient = Patient,
    breathing = factor(breathing, levels = breathing_levels, ordered = TRUE),
    day = as.integer(day)
  )

# Series in the column layout passed to brmshmmdata(); time is shifted so that
# day 0 becomes .time == 1.
series_data <- data_grein %>%
  transmute(
    .serie = as.integer(patient),
    .time = day + 1,
    .observed = breathing
  )

# Quick look at the raw data: state of every patient on every day.
ggplot(data_grein, aes(x = day, y = patient, fill = breathing)) +
  geom_tile()

# Number of observations per state ("pocet" = count).
data_grein %>%
  group_by(breathing) %>%
  summarise(pocet = n())

# Share of observations that are "LoO2".
mean(data_grein$breathing == "LoO2", na.rm = TRUE)
# ---- Model 1: one hidden state per observed state ---------------------------

# Every hidden state corresponds to the observed state of the same name.
m1_hidden_state_data <- tibble(
  id = factor(breathing_levels, levels = breathing_levels)
) %>%
  mutate(corresponding_obs = id)

m1_observed_state_data <- tibble(
  id = factor(breathing_levels, levels = breathing_levels),
  is_noisy = FALSE
)

#' Add predictor columns for one rate group
#'
#' Adds two columns to `rate_data`:
#' * `is_<group_name>`: 1 for rates in the group, 0 otherwise (double).
#' * `<group_name>_id`: ordered factor with levels `0 .. n - 1` (`n` = number
#'   of rates in the group). For rates in the group the value is the integer
#'   code of `.rate_id` minus the smallest such code within the group; for all
#'   other rates it is 0. Rates of one group therefore need consecutive
#'   `.rate_id` codes, otherwise some offsets fall outside the levels and
#'   become `NA`.
#'
#' @param rate_data Data frame with columns `rate_group` and `.rate_id`
#'   (a factor).
#' @param group_name Single string, a value occurring in `rate_group`.
#' @return `rate_data` with the two columns added.
recode_group <- function(rate_data, group_name) {
  # Without this guard an unknown group (e.g. a typo) would only trigger a
  # warning from min() and silently produce meaningless factor levels.
  if (!any(rate_data$rate_group == group_name, na.rm = TRUE)) {
    stop(
      "No rows of `rate_data` have rate_group == \"", group_name, "\"",
      call. = FALSE
    )
  }

  is_column <- paste0("is_", group_name)
  id_column <- paste0(group_name, "_id")

  rate_data %>%
    mutate(
      !!is_column := as.double(rate_group == group_name),
      !!id_column := factor(
        if_else(
          rate_group == group_name,
          as.integer(.rate_id) -
            min(as.integer(.rate_id[rate_group == group_name])),
          as.integer(0)
        ),
        levels = 0:(sum(rate_group == group_name) - 1),
        ordered = TRUE
      )
    )
}

# Allowed transitions: death from any disease state, one step of improvement
# (AA improves to Discharge) and one step of worsening.
m1_rate_data <- rbind(
  tibble(.from = disease_levels, .to = "Death", rate_group = "death"),
  tibble(.from = "AA", .to = "Discharge", rate_group = "improve_one"),
  tibble(
    .from = disease_levels[2:length(disease_levels)],
    .to = disease_levels[1:(length(disease_levels) - 1)],
    rate_group = "improve_one"
  ),
  tibble(
    .from = disease_levels[1:(length(disease_levels) - 1)],
    .to = disease_levels[2:length(disease_levels)],
    rate_group = "worsen_one"
  )
) %>%
  mutate(
    .rate_id = factor(1:n()),
    .from = factor(.from, levels = breathing_levels),
    .to = factor(.to, levels = breathing_levels)
  ) %>%
  recode_group("death") %>%
  recode_group("improve_one") %>%
  recode_group("worsen_one")

# Initial hidden state of each series = its first observation.
m1_initial_states <- series_data %>%
  filter(.time == 1) %>%
  arrange(.serie) %>%
  pull(.observed)

m1_prior <- brms::set_prior("normal(-2, 5)", "Intercept") +
  brms::set_prior("normal(0, 5)", "b") +
  brms::set_prior("normal(0, 2)", "sd")

m1_formula <- ~ is_improve_one + is_death + mo(death_id) +
  (0 + is_worsen_one | worsen_one_id) +
  (0 + is_improve_one | improve_one_id)

m1 <- brmshmmdata(
  m1_formula, series_data,
  m1_rate_data, m1_hidden_state_data, m1_initial_states,
  prior = m1_prior,
  observed_state_data = m1_observed_state_data
)

make_stancode_hmm(m1)
#xx <- brms::make_standata(m1$formula, data = make_data_hmm(m1)$brmsdata)

m1_fit <- brmhmm(m1, cache_file = paste0(cache_dir, "/grein_m1.rds"))

summary(m1_fit$brmsfit)
# Posterior predictions for model 1.
# NOTE(review): the posterior_* helpers come from the package loaded via
# devtools::load_all(); judging by their names they convert expected values
# ("epred") into predicted states and reshape them - confirm in the package.
m1_epred_rect <- posterior_epred_rect(m1_fit)
m1_predicted_rect <- posterior_epred_to_predicted(m1_fit, m1_epred_rect)
m1_predicted <- posterior_rect_to_long(m1_fit, m1_predicted_rect)

m1_predicted_df <- posterior_long_to_df(m1_fit$data, m1_predicted)

# Plot the series in pages of `step_size`.
step_size <- 6
m1_n_series <- m1_fit$data_processed$standata$N_series
# seq_len() yields no iterations when there are no series (1:0 would yield two).
for (step in seq_len(ceiling(m1_n_series / step_size))) {
  # Cap the last page so that only existing series ids are requested.
  series_id <- ((step - 1) * step_size + 1):min(step * step_size, m1_n_series)
  posterior_state_plot(m1_predicted_df, series_id, m1_fit$data) %>%
    print()
}
# Posterior predictive checks for model 1.
pp_check_last_state(m1_fit, m1_predicted_rect)

pp_check_transitions_direction(m1_fit, m1_predicted_rect)

# Transition checks split into two ranges of originating states.
# NOTE(review): states_from presumably indexes breathing_levels
# (2:4 = AA..HiO2, 5:7 = NIPPV..ECMO) - confirm in pp_check_transitions().
pp_check_transitions(
  m1_fit,
  m1_predicted_rect,
  states_from = 2:4
)
pp_check_transitions(
  m1_fit,
  m1_predicted_rect,
  states_from = 5:7
)
# ---- Model 2: "improving" and "worsening" hidden variant per disease state ---

# Two hidden states per disease level, both mapped to the same observation,
# plus the two terminal states.
m2_hidden_state_data <- rbind(
  tibble(base = disease_levels) %>%
    mutate(corresponding_obs = base) %>%
    crossing(tibble(state_group = c("_improving", "_worsening"))) %>%
    mutate(id = factor(paste0(base, state_group))) %>%
    select(-base, -state_group),
  tibble(id = c("Death", "Discharge")) %>%
    mutate(corresponding_obs = id)
)

m2_observed_state_data <- m1_observed_state_data

# Death and worsening happen only from "_worsening" states, improvement only
# from "_improving" states; within a level the two variants can switch.
m2_rate_data <- rbind(
  tibble(
    .from = paste0(disease_levels, "_worsening"),
    .to = "Death",
    rate_group = "death"
  ),
  tibble(
    .from = "AA_improving",
    .to = "Discharge",
    rate_group = "improve_one"
  ),
  tibble(
    .from = paste0(disease_levels[2:length(disease_levels)], "_improving"),
    .to = paste0(disease_levels[1:(length(disease_levels) - 1)], "_improving"),
    rate_group = "improve_one"
  ),
  tibble(
    .from = paste0(disease_levels[1:(length(disease_levels) - 1)], "_worsening"),
    .to = paste0(disease_levels[2:length(disease_levels)], "_worsening"),
    rate_group = "worsen_one"
  ),
  tibble(
    .from = paste0(disease_levels, "_worsening"),
    .to = paste0(disease_levels, "_improving"),
    rate_group = "to_improving"
  ),
  tibble(
    .from = paste0(disease_levels, "_improving"),
    .to = paste0(disease_levels, "_worsening"),
    rate_group = "to_worsening"
  )
) %>%
  mutate(.rate_id = factor(1:n())) %>%
  recode_group("death") %>%
  recode_group("improve_one") %>%
  recode_group("worsen_one") %>%
  recode_group("to_improving") %>%
  recode_group("to_worsening")

# Every series starts in the "_worsening" variant of its first observation.
m2_initial_states <- series_data %>%
  filter(.time == 1) %>%
  arrange(.serie) %>%
  pull(.observed) %>%
  paste0(., "_worsening")

m2_prior <- m1_prior

m2_formula <- update.formula(
  m1_formula,
  ~ . + is_to_improving + is_to_worsening +
    (0 + is_to_improving | to_improving_id) +
    (0 + is_to_worsening | to_worsening_id)
)

# Uses m2_observed_state_data (identical to m1_observed_state_data) so that
# the model refers only to its own objects.
m2 <- brmshmmdata(
  m2_formula, series_data,
  m2_rate_data, m2_hidden_state_data, m2_initial_states,
  prior = m2_prior,
  observed_state_data = m2_observed_state_data
)

m2_fit <- brmhmm(
  m2,
  cache_file = paste0(cache_dir, "/grein_m2.rds"),
  control = list(adapt_delta = 0.9)
)

summary(m2_fit$brmsfit)
# Posterior predictions for model 2.
# NOTE(review): the posterior_* helpers come from the package loaded via
# devtools::load_all(); confirm their exact semantics in the package.
m2_epred_rect <- posterior_epred_rect(m2_fit)
m2_predicted_rect <- posterior_epred_to_predicted(m2_fit, m2_epred_rect)
m2_predicted <- posterior_rect_to_long(m2_fit, m2_predicted_rect)

m2_predicted_df <- posterior_long_to_df(m2_fit$data, m2_predicted)

# Plot the series in pages of `step_size`.
step_size <- 6
m2_n_series <- m2_fit$data_processed$standata$N_series
# seq_len() yields no iterations when there are no series (1:0 would yield two).
for (step in seq_len(ceiling(m2_n_series / step_size))) {
  # Cap the last page so that only existing series ids are requested.
  series_id <- ((step - 1) * step_size + 1):min(step * step_size, m2_n_series)
  posterior_state_plot(m2_predicted_df, series_id, m2_fit$data) %>%
    print()
}
# Posterior predictive checks for model 2.
pp_check_last_state(m2_fit, m2_predicted_rect)

# Direction of transitions: first with the default originating states,
# then restricted to states 2:7.
pp_check_transitions_direction(m2_fit, m2_predicted_rect)
pp_check_transitions_direction(
  m2_fit,
  m2_predicted_rect,
  states_from = 2:7
)

# Individual transitions, split into two ranges of originating states.
pp_check_transitions(
  m2_fit,
  m2_predicted_rect,
  states_from = 2:4
)
pp_check_transitions(
  m2_fit,
  m2_predicted_rect,
  states_from = 5:7
)
# ---- Model 3: model 1 structure with noisy observations ---------------------

m3_hidden_state_data <- m1_hidden_state_data

# All disease states are flagged as noisy; Discharge and Death are not.
# other_obs_1 .. other_obs_5 list, for each disease state, the five remaining
# disease states (cyclic shifts of disease_levels).
m3_observed_state_data <- tibble(
  id = breathing_levels,
  is_noisy = c(FALSE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, FALSE),
  other_obs_1 = c(NA, disease_levels[2:6], disease_levels[1], NA),
  other_obs_2 = c(NA, disease_levels[3:6], disease_levels[1:2], NA),
  other_obs_3 = c(NA, disease_levels[4:6], disease_levels[1:3], NA),
  other_obs_4 = c(NA, disease_levels[5:6], disease_levels[1:4], NA),
  other_obs_5 = c(NA, disease_levels[6], disease_levels[1:5], NA)
) %>%
  mutate(id = factor(id, levels = breathing_levels))

m3_rate_data <- m1_rate_data

m3_initial_states <- m1_initial_states

m3_prior <- m1_prior

m3_formula <- m1_formula

m3 <- brmshmmdata(
  m3_formula, series_data,
  m3_rate_data, m3_hidden_state_data, m3_initial_states,
  prior = m3_prior,
  observed_state_data = m3_observed_state_data
)

m3_fit <- brmhmm(
  m3,
  cache_file = paste0(cache_dir, "/grein_m3.rds"),
  control = list(adapt_delta = 0.9)
)

# NOTE(review): the other models call summary() on `$brmsfit`; confirm that
# summarising the whole fit object is intended here.
summary(m3_fit)

# Observation-noise parameters of the underlying Stan fit. The original code
# referenced an undefined object `fit`; the fit of this model is `m3_fit`.
summary(
  m3_fit$brmsfit$fit,
  pars = c("sensitivity", "other_observations_probs"),
  probs = c(0.025, 0.975)
)$summary
# Posterior predictions for model 3.
# NOTE(review): the posterior_* helpers come from the package loaded via
# devtools::load_all(); confirm their exact semantics in the package.
m3_epred_rect <- posterior_epred_rect(m3_fit)
m3_predicted_rect <- posterior_epred_to_predicted(m3_fit, m3_epred_rect)
m3_predicted <- posterior_rect_to_long(m3_fit, m3_predicted_rect)

m3_predicted_df <- posterior_long_to_df(m3_fit$data, m3_predicted)

# Plot the series in pages of `step_size`.
step_size <- 6
m3_n_series <- m3_fit$data_processed$standata$N_series
# seq_len() yields no iterations when there are no series (1:0 would yield two).
for (step in seq_len(ceiling(m3_n_series / step_size))) {
  # Cap the last page so that only existing series ids are requested.
  series_id <- ((step - 1) * step_size + 1):min(step * step_size, m3_n_series)
  posterior_state_plot(m3_predicted_df, series_id, m3_fit$data) %>%
    print()
}
# Posterior predictive checks for model 3.
pp_check_last_state(m3_fit, m3_predicted_rect)

# Direction of transitions: first with the default originating states,
# then restricted to states 2:7.
pp_check_transitions_direction(m3_fit, m3_predicted_rect)
pp_check_transitions_direction(
  m3_fit,
  m3_predicted_rect,
  states_from = 2:7
)

# Individual transitions, split into two ranges of originating states.
pp_check_transitions(
  m3_fit,
  m3_predicted_rect,
  states_from = 2:4
)
pp_check_transitions(
  m3_fit,
  m3_predicted_rect,
  states_from = 5:7
)
# ---- Model 4: improving / stable / worsening variant per disease state ------

# Three hidden states per disease level, all mapped to the same observation,
# plus the two terminal states.
m4_hidden_state_data <- rbind(
  tibble(base = disease_levels) %>%
    mutate(corresponding_obs = base) %>%
    crossing(
      tibble(state_group = c("_improving", "_stable", "_worsening"))
    ) %>%
    mutate(id = factor(paste0(base, state_group))) %>%
    select(-base, -state_group),
  tibble(id = c("Death", "Discharge")) %>%
    mutate(corresponding_obs = id)
)

m4_observed_state_data <- m1_observed_state_data

# Death and worsening only from "_worsening" states, improvement only from
# "_improving" states; within a level every pair of variants can switch.
m4_rate_data <- rbind(
  tibble(
    .from = paste0(disease_levels, "_worsening"),
    .to = "Death",
    rate_group = "death"
  ),
  tibble(
    .from = "AA_improving",
    .to = "Discharge",
    rate_group = "improve_one"
  ),
  tibble(
    .from = paste0(disease_levels[2:length(disease_levels)], "_improving"),
    .to = paste0(disease_levels[1:(length(disease_levels) - 1)], "_improving"),
    rate_group = "improve_one"
  ),
  tibble(
    .from = paste0(disease_levels[1:(length(disease_levels) - 1)], "_worsening"),
    .to = paste0(disease_levels[2:length(disease_levels)], "_worsening"),
    rate_group = "worsen_one"
  ),
  tibble(
    .from = paste0(disease_levels, "_stable"),
    .to = paste0(disease_levels, "_improving"),
    rate_group = "to_improving"
  ),
  tibble(
    .from = paste0(disease_levels, "_stable"),
    .to = paste0(disease_levels, "_worsening"),
    rate_group = "to_worsening"
  ),
  tibble(
    .from = paste0(disease_levels, "_worsening"),
    .to = paste0(disease_levels, "_stable"),
    rate_group = "worsening_to_stable"
  ),
  tibble(
    .from = paste0(disease_levels, "_improving"),
    .to = paste0(disease_levels, "_stable"),
    rate_group = "improving_to_stable"
  ),
  tibble(
    .from = paste0(disease_levels, "_improving"),
    .to = paste0(disease_levels, "_worsening"),
    rate_group = "improving_to_worsening"
  ),
  tibble(
    .from = paste0(disease_levels, "_worsening"),
    .to = paste0(disease_levels, "_improving"),
    rate_group = "worsening_to_improving"
  )
) %>%
  mutate(.rate_id = factor(1:n())) %>%
  recode_group("death") %>%
  recode_group("improve_one") %>%
  recode_group("worsen_one") %>%
  recode_group("to_improving") %>%
  recode_group("to_worsening") %>%
  recode_group("improving_to_stable") %>%
  recode_group("worsening_to_stable") %>%
  recode_group("improving_to_worsening") %>%
  recode_group("worsening_to_improving")

# Every series starts in the "_worsening" variant of its first observation.
m4_initial_states <- series_data %>%
  filter(.time == 1) %>%
  arrange(.serie) %>%
  pull(.observed) %>%
  paste0(., "_worsening")

m4_prior <- m1_prior

# Each group-level term pairs the indicator of a rate group with the id column
# of the SAME group. The original paired is_improving_to_worsening with
# worsening_to_improving_id (and vice versa); those id columns are constant 0
# wherever the indicator is 1, so the effects could not vary within the group.
# NOTE(review): a fit cached under grein_m4.rds from the old formula is stale -
# confirm that brmhmm() refits when the model definition changes.
m4_formula <- update.formula(
  m2_formula,
  ~ . + is_improving_to_stable + is_worsening_to_stable +
    is_improving_to_worsening + is_worsening_to_improving +
    (0 + is_improving_to_stable | improving_to_stable_id) +
    (0 + is_worsening_to_stable | worsening_to_stable_id) +
    (0 + is_improving_to_worsening | improving_to_worsening_id) +
    (0 + is_worsening_to_improving | worsening_to_improving_id)
)

# Uses m4_observed_state_data (identical to m1_observed_state_data) so that
# the model refers only to its own objects.
m4 <- brmshmmdata(
  m4_formula, series_data,
  m4_rate_data, m4_hidden_state_data, m4_initial_states,
  prior = m4_prior,
  observed_state_data = m4_observed_state_data
)

m4_fit <- brmhmm(
  m4,
  cache_file = paste0(cache_dir, "/grein_m4.rds"),
  control = list(adapt_delta = 0.9),
  cores = 1
)

summary(m4_fit$brmsfit)
# Commented-out variant of the m4 definition: it has no recode_group() calls
# or formula terms for the "improving_to_worsening" / "worsening_to_improving"
# rate groups.
#
# m4_hidden_state_data <- rbind(
#   tibble(base = disease_levels) %>% mutate(corresponding_obs = base) %>%
#     crossing(tibble(state_group = c("_improving", "_stable", "_worsening"))) %>%
#     mutate(id = factor(paste0(base, state_group))) %>%
#     select(-base, -state_group),
#   tibble(id = c("Death","Discharge")) %>% mutate(corresponding_obs = id)
# )
#
# m4_observed_state_data <- m1_observed_state_data
#
# m4_rate_data <- rbind(
#   tibble(.from = paste0(disease_levels, "_worsening"), .to = "Death", rate_group = "death"),
#   tibble(.from = "AA_improving", .to = "Discharge", rate_group = "improve_one"),
#   tibble(.from = paste0(disease_levels[2 : length(disease_levels)], "_improving"),
#          .to = paste0(disease_levels[1 : (length(disease_levels) - 1)],"_improving"),
#          rate_group = "improve_one"),
#   tibble(.from = paste0(disease_levels[1 : (length(disease_levels) - 1)],"_worsening"),
#          .to = paste0(disease_levels[2 : length(disease_levels)], "_worsening"),
#          rate_group = "worsen_one"),
#   tibble(.from = paste0(disease_levels,"_stable"),
#          .to = paste0(disease_levels, "_improving"),
#          rate_group = "to_improving"),
#   tibble(.from = paste0(disease_levels,"_stable"),
#          .to = paste0(disease_levels, "_worsening"),
#          rate_group = "to_worsening"),
#   tibble(.from = paste0(disease_levels,"_worsening"),
#          .to = paste0(disease_levels, "_stable"),
#          rate_group = "worsening_to_stable"),
#   tibble(.from = paste0(disease_levels,"_improving"),
#          .to = paste0(disease_levels, "_stable"),
#          rate_group = "improving_to_stable"),
#   tibble(.from = paste0(disease_levels,"_improving"),
#          .to = paste0(disease_levels, "_worsening"),
#          rate_group = "improving_to_worsening"),
#   tibble(.from = paste0(disease_levels,"_worsening"),
#          .to = paste0(disease_levels, "_improving"),
#          rate_group = "worsening_to_improving")
# ) %>%
#   mutate(.rate_id = factor(1:n())) %>%
#   recode_group("death") %>%
#   recode_group("improve_one") %>%
#   recode_group("worsen_one") %>%
#   recode_group("to_improving") %>%
#   recode_group("to_worsening") %>%
#   recode_group("improving_to_stable") %>%
#   recode_group("worsening_to_stable")
#
#
# m4_initial_states <- series_data %>% filter(.time == 1) %>% arrange(.serie) %>% pull(.observed) %>% paste0(., "_worsening")
#
# m4_prior = m1_prior
#
# m4_formula <- update.formula(m2_formula,
#   ~ . + is_improving_to_stable + is_worsening_to_stable + (0 + is_improving_to_stable | improving_to_stable_id) +
#     (0 + is_worsening_to_stable | worsening_to_stable_id))
#
# m4 <- brmshmmdata(m4_formula, series_data,
#                   m4_rate_data, m4_hidden_state_data, m4_initial_states, prior = m4_prior,
#                   observed_state_data = m1_observed_state_data)
#
# m4_fit <- brmhmm(m4, cache_file = paste0(cache_dir,"/grein_m4.rds"), control = list(adapt_delta = 0.9))
#
# summary(m4_fit$brmsfit)
# Posterior predictions for model 4.
# NOTE(review): the posterior_* helpers come from the package loaded via
# devtools::load_all(); confirm their exact semantics in the package.
m4_epred_rect <- posterior_epred_rect(m4_fit)
m4_predicted_rect <- posterior_epred_to_predicted(m4_fit, m4_epred_rect)
m4_predicted <- posterior_rect_to_long(m4_fit, m4_predicted_rect)

m4_predicted_df <- posterior_long_to_df(m4_fit$data, m4_predicted)

# Plot the series in pages of `step_size`.
step_size <- 6
m4_n_series <- m4_fit$data_processed$standata$N_series
# seq_len() yields no iterations when there are no series (1:0 would yield two).
for (step in seq_len(ceiling(m4_n_series / step_size))) {
  # Cap the last page so that only existing series ids are requested.
  series_id <- ((step - 1) * step_size + 1):min(step * step_size, m4_n_series)
  posterior_state_plot(m4_predicted_df, series_id, m4_fit$data) %>%
    print()
}
# Posterior predictive checks for model 4.
pp_check_last_state(m4_fit, m4_predicted_rect)

# Direction of transitions: first with the default originating states,
# then restricted to states 2:7.
pp_check_transitions_direction(m4_fit, m4_predicted_rect)
pp_check_transitions_direction(
  m4_fit,
  m4_predicted_rect,
  states_from = 2:7
)

# Individual transitions, split into two ranges of originating states.
pp_check_transitions(
  m4_fit,
  m4_predicted_rect,
  states_from = 2:4
)
pp_check_transitions(
  m4_fit,
  m4_predicted_rect,
  states_from = 5:7
)
# ---- Model 5: coarser observed states, one coefficient per rate -------------

# Reduced set of observed states: LoO2 and HiO2 are merged into "O2" and
# NIPPV is merged into "MV" (see the recoding of series_data_m5 below).
disease_levels_m5 <- c("AA", "O2", "MV", "ECMO")
breathing_levels_m5 <- c("Discharge", disease_levels_m5, "Death")

m5_observed_state_data <- tibble(
  id = factor(breathing_levels_m5, levels = breathing_levels_m5),
  is_noisy = FALSE
)

# Three hidden states per disease level plus the two terminal states.
# (A previous assignment derived from m4_hidden_state_data was overwritten
# right away by this one and has been removed as a dead store.)
m5_hidden_state_data <- rbind(
  tibble(base = disease_levels_m5) %>%
    mutate(corresponding_obs = base) %>%
    crossing(
      tibble(state_group = c("_improving", "_stable", "_worsening"))
    ) %>%
    mutate(id = factor(paste0(base, state_group))) %>%
    select(-base, -state_group),
  tibble(id = c("Death", "Discharge")) %>%
    mutate(corresponding_obs = id)
)

m5_rate_data <- rbind(
  tibble(
    .from = paste0(disease_levels_m5, "_worsening"),
    .to = "Death",
    rate_group = "death"
  ),
  tibble(
    .from = "AA_improving",
    .to = "Discharge",
    rate_group = "improve_one"
  ),
  tibble(
    .from = paste0(
      disease_levels_m5[2:length(disease_levels_m5)], "_improving"
    ),
    .to = paste0(
      disease_levels_m5[1:(length(disease_levels_m5) - 1)], "_improving"
    ),
    rate_group = "improve_one"
  ),
  tibble(
    .from = paste0(
      disease_levels_m5[1:(length(disease_levels_m5) - 1)], "_worsening"
    ),
    .to = paste0(
      disease_levels_m5[2:length(disease_levels_m5)], "_worsening"
    ),
    rate_group = "worsen_one"
  ),
  tibble(
    .from = paste0(disease_levels_m5, "_stable"),
    .to = paste0(disease_levels_m5, "_improving"),
    rate_group = "to_improving"
  ),
  tibble(
    .from = paste0(disease_levels_m5, "_stable"),
    .to = paste0(disease_levels_m5, "_worsening"),
    rate_group = "to_worsening"
  ),
  tibble(
    .from = paste0(disease_levels_m5, "_worsening"),
    .to = paste0(disease_levels_m5, "_stable"),
    rate_group = "worsening_to_stable"
  ),
  tibble(
    .from = paste0(disease_levels_m5, "_improving"),
    .to = paste0(disease_levels_m5, "_stable"),
    rate_group = "improving_to_stable"
  )
) %>%
  mutate(.rate_id = factor(1:n())) %>%
  recode_group("death") %>%
  recode_group("improve_one") %>%
  recode_group("worsen_one") %>%
  recode_group("to_improving") %>%
  recode_group("to_worsening") %>%
  recode_group("improving_to_stable") %>%
  recode_group("worsening_to_stable")

# Recode the observations to the reduced set of states.
series_data_m5 <- series_data %>%
  mutate(
    .observed = factor(
      case_when(
        .observed == "NIPPV" ~ "MV",
        .observed %in% c("LoO2", "HiO2") ~ "O2",
        TRUE ~ as.character(.observed)
      ),
      levels = breathing_levels_m5,
      ordered = TRUE
    )
  )

# Every series starts in the "_worsening" variant of its first observation.
m5_initial_states <- series_data_m5 %>%
  filter(.time == 1) %>%
  arrange(.serie) %>%
  pull(.observed) %>%
  paste0(., "_worsening")

# No "sd" prior: the formula below has no group-level terms.
m5_prior <- brms::set_prior("normal(-2, 5)", "Intercept") +
  brms::set_prior("normal(0, 5)", "b")

# One coefficient per rate instead of the structured m4 formula.
m5_formula <- ~ factor(.rate_id) #m4_formula

m5 <- brmshmmdata(
  m5_formula, series_data_m5,
  m5_rate_data, m5_hidden_state_data, m5_initial_states,
  prior = m5_prior,
  observed_state_data = m5_observed_state_data
)

m5_fit <- brmhmm(
  m5,
  cache_file = paste0(cache_dir, "/grein_m5.rds"),
  control = list(adapt_delta = 0.9)
)

summary(m5_fit$brmsfit)
# Posterior predictions for model 5.
# NOTE(review): the posterior_* helpers come from the package loaded via
# devtools::load_all(); confirm their exact semantics in the package.
m5_epred_rect <- posterior_epred_rect(m5_fit)
m5_predicted_rect <- posterior_epred_to_predicted(m5_fit, m5_epred_rect)
m5_predicted <- posterior_rect_to_long(m5_fit, m5_predicted_rect)

m5_predicted_df <- posterior_long_to_df(m5_fit$data, m5_predicted)

# Plot the series in pages of `step_size`.
step_size <- 6
m5_n_series <- m5_fit$data_processed$standata$N_series
# seq_len() yields no iterations when there are no series (1:0 would yield two).
for (step in seq_len(ceiling(m5_n_series / step_size))) {
  # Cap the last page so that only existing series ids are requested.
  series_id <- ((step - 1) * step_size + 1):min(step * step_size, m5_n_series)
  posterior_state_plot(m5_predicted_df, series_id, m5_fit$data) %>%
    print()
}
# Posterior predictive checks for model 5.
pp_check_last_state(m5_fit, m5_predicted_rect)

# Direction of transitions: first with the default originating states,
# then restricted to states 2:6.
# NOTE(review): breathing_levels_m5 has six entries with "Death" last; if
# states_from indexes observed states, 2:6 and 5:6 include "Death", unlike
# the 2:7 range used for the eight-level models - confirm this is intended.
pp_check_transitions_direction(m5_fit, m5_predicted_rect)
pp_check_transitions_direction(
  m5_fit,
  m5_predicted_rect,
  states_from = 2:6
)

# Individual transitions, split into two ranges of originating states.
pp_check_transitions(
  m5_fit,
  m5_predicted_rect,
  states_from = 2:4
)
pp_check_transitions(
  m5_fit,
  m5_predicted_rect,
  states_from = 5:6
)
Add the following code to your website.
For more information on customizing the embed code, read Embedding Snippets.