Hidden Markov Models

knitr::opts_chunk$set(echo = FALSE)

devtools::load_all()
library(tidyverse)
options(mc.cores = parallel::detectCores())
rstan::rstan_options(auto_write = TRUE)
options(brms.backend = "cmdstanr")

theme_set(cowplot::theme_cowplot())
data <- read_data_for_analysis()

if(length(unique(data$patient_data$patient_id)) != length(data$patient_data$patient_id)) {
  stop("Duplicate patient IDs in data$patient_data")
}

wide <- data$marker_data_wide


cache_dir <- here::here("local_temp_data", "hmm")
if(!dir.exists(cache_dir)) {
  dir.create(cache_dir, recursive = TRUE)
}

do_pp_checks <- FALSE
hmm_hypothesis_res_list <- list()

cl <-  parallel::makePSOCKcluster(parallel::detectCores())

# Temporary hack: register the custom stan_log_lik() method for class rate_hmm in the brms namespace
registerS3method("stan_log_lik", class = "rate_hmm", method = stan_log_lik.rate_hmm, envir = asNamespace("brms"))

The overall approach for hidden Markov models

To complement the more traditionally used proportional hazards models, we also use several variants of hidden Markov models (HMMs). A discrete-time random process $X_1, X_2, X_3, \ldots$ satisfies the Markov property if for all states $x$ and times $t$ we have $P(X_{t + 1} = x \mid X_1 = x_1, X_2 = x_2, \ldots, X_t = x_t) = P(X_{t+1} = x \mid X_t = x_t)$, i.e. when the process is "memoryless". If this process is unobservable, but we can observe another process $Y$ such that $P(Y_{t} = y \mid X_1 = x_1, X_2 = x_2, \ldots, X_n = x_n) = P(Y_{t} = y \mid X_t = x_t)$ for all $n \geq t$, we call the pair $(X, Y)$ a hidden Markov process.

We treat the breathing support required as a Markov process. The states either correspond directly to the breathing support and are directly observable (Fig. \@ref(fig:hmmstates)a), or they additionally carry a binary improving/worsening component which is unobservable (Fig. \@ref(fig:hmmstates)b), while the breathing support dimension remains completely observable. In the following text, the former, completely observable model is called simple and the latter is called complex. In both cases the observation matrix is very restricted: each hidden state emits exactly one observed state with probability 1. The reason is that the breathing support actually used is very likely to correspond to the breathing support needed by the patient. In initial versions of our models we also tested models that treated the observed breathing support as a noisy realisation of the actual state, but the fitted probabilities of such imprecise observations were always very low, so we ended up not using them.
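The following minimal sketch shows how such a restricted (0/1) observation matrix enters the likelihood, using the standard HMM forward algorithm in R. It is not the Stan implementation used for the fits. The function name `forward_loglik`, the three-state example and all probabilities are made-up assumptions. Because each hidden state emits exactly one observed state, the observation step only zeroes out the hidden states that disagree with the observation. Days with a missing observation contribute nothing.

```r
# Sketch (not the project's Stan code): forward algorithm for an HMM in which
# every hidden state emits exactly one observed state with probability 1.
forward_loglik <- function(obs, S, p0, corresponding_obs) {
  # obs: observed state per day (NA = not observed that day)
  # S: hidden-state transition matrix, rows = from, columns = to
  # p0: initial hidden-state probabilities
  # corresponding_obs: the observed state emitted by each hidden state
  alpha <- p0
  loglik <- 0
  for (t in seq_along(obs)) {
    # Propagate the filtered distribution one day forward
    if (t > 1) alpha <- as.vector(alpha %*% S)
    if (!is.na(obs[t])) {
      # 0/1 observation matrix: keep only hidden states matching obs[t]
      alpha <- alpha * (corresponding_obs == obs[t])
    }
    total <- sum(alpha)
    if (total == 0) return(-Inf)  # observed sequence is impossible
    loglik <- loglik + log(total)
    alpha <- alpha / total        # normalise to avoid underflow
  }
  loglik
}

# Illustrative example with made-up probabilities
hidden <- c("AA_improving", "AA_worsening", "Discharged")
corresponding_obs <- c("AA", "AA", "Discharged")
S <- matrix(c(0.7, 0.2, 0.1,
              0.3, 0.7, 0.0,
              0.0, 0.0, 1.0),
            nrow = 3, byrow = TRUE, dimnames = list(hidden, hidden))
p0 <- c(0, 1, 0)  # start in AA_worsening
ll <- forward_loglik(c("AA", "AA", "Discharged"), S, p0, corresponding_obs)
exp(ll)  # 0.03 = P(AA_worsening -> AA_improving) * P(AA_improving -> Discharged)
```

With a noisy observation model, the masking line would instead multiply `alpha` by the emission probabilities of `obs[t]` for each hidden state.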

Directly modeling the full transition matrix of the Markov process would not make the best use of the data, as we know there is additional structure expressed in the ordering of states. For example, the probability of being discharged from the "AA" state is higher than from the "Oxygen" state, and transitioning to "Ventilated" is more likely from the "Oxygen" state than from the "AA" state. To incorporate this structure, we follow the approach outlined in [@http://zotero.org/users/5567156/items/F9K2L5Q8].

We set up a rate matrix $R$ so that for any two states $i \neq j$, $R_{i,j}$ is the rate of transition from $i$ to $j$. $R$ is intended to be sparse, i.e. $R_{i,j} \neq 0$ only for transitions explicitly intended to be modeled directly. The diagonal elements are set as $R_{i,i} = -\sum_{j \neq i} R_{i,j}$, so that each row of $R$ sums to zero.
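Here is a minimal sketch of this construction in base R. It uses the seven transitions of the simple model described below. The state ordering and the rate values are illustrative assumptions, not fitted estimates, and the helper name `build_rate_matrix` is hypothetical.

```r
# Sketch: build a sparse rate matrix R from a table of modeled transitions.
# States and rate values are made-up for illustration.
states <- c("Discharged", "AA", "Oxygen", "Ventilated", "Death")
rates <- data.frame(
  from = c("AA", "Oxygen", "Ventilated", "AA", "Oxygen", "Oxygen", "Ventilated"),
  to   = c("Discharged", "AA", "Oxygen", "Oxygen", "Ventilated", "Death", "Death"),
  rate = c(0.20, 0.15, 0.10, 0.05, 0.04, 0.02, 0.05)
)

build_rate_matrix <- function(states, rates) {
  R <- matrix(0, nrow = length(states), ncol = length(states),
              dimnames = list(states, states))
  # Off-diagonal entries: only explicitly modeled transitions are non-zero
  R[cbind(as.character(rates$from), as.character(rates$to))] <- rates$rate
  # Diagonal entries make every row sum to zero
  diag(R) <- -rowSums(R)
  R
}

R <- build_rate_matrix(states, rates)
R["Oxygen", ]  # AA 0.15, Ventilated 0.04, Death 0.02, diagonal -0.21
```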

knitr::include_graphics(here::here("manuscript","static_figs","model_states_hmm.png"), auto_pdf = TRUE)

Treating $p(t)$ as a row vector of the state probabilities in continuous time $t$, its evolution is then given by the differential equation

$$ \frac{dp(t)}{dt} = p(t) R $$

Given the initial state probabilities $p(0)$, the solution to this equation is:

$$ p(t) = p(0) \exp(tR) $$

where $\exp$ is the matrix exponential. We can thus compute a discrete-time transition matrix $S = \exp(R)$ so that $p(t + 1) = p(t) S$, where $S_{i,j}$ is the probability of moving from state $i$ to state $j$ within one day. The appeal of this approach is that we can enforce a lot of structure on the transition matrix $S$ while allowing positive probability for almost all transitions. In the case of the complex model, the full transition matrix between the 8 states would have 42 free parameters (6 transient states, each with 7 free transition probabilities, as there are no transitions from the "Death" and "Discharged" states), while the rate matrix has only 13 free parameters. We could instead use a sparse transition matrix directly, but that would make the model overly rigid, as transitions that are not modeled directly but actually occur in the data (e.g. from "Oxygen" to "Discharged") would have zero probability.
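The step from $R$ to the daily transition matrix can be sketched as follows. Base R has no built-in matrix exponential, so the hypothetical helper `mat_exp` uses scaling and squaring with a truncated Taylor series. In practice, a dedicated routine such as `expm::expm()` in R or `matrix_exp()` in Stan would be preferable. Rows index the origin state, matching the convention that $R_{i,j}$ is the rate from $i$ to $j$. The rates are made-up values.

```r
# Sketch: matrix exponential via scaling and squaring with a Taylor series,
# using exp(A) = exp(A / 2^s)^(2^s). Hypothetical helper, for illustration.
mat_exp <- function(A, order = 20) {
  nrm <- max(rowSums(abs(A)))
  # Choose s so that the scaled matrix has norm at most 1/2
  s <- if (nrm > 0) max(0, ceiling(log2(nrm)) + 1) else 0
  A_s <- A / 2^s
  result <- diag(nrow(A))
  term <- diag(nrow(A))
  for (k in seq_len(order)) {
    term <- term %*% A_s / k   # A_s^k / k!
    result <- result + term
  }
  for (i in seq_len(s)) result <- result %*% result
  dimnames(result) <- dimnames(A)
  result
}

# Illustrative rate matrix with the structure of the simple model
states <- c("Discharged", "AA", "Oxygen", "Ventilated", "Death")
R <- matrix(0, 5, 5, dimnames = list(states, states))
R["AA", "Discharged"] <- 0.20
R["Oxygen", "AA"] <- 0.15
R["Ventilated", "Oxygen"] <- 0.10
R["AA", "Oxygen"] <- 0.05
R["Oxygen", "Ventilated"] <- 0.04
R["Oxygen", "Death"] <- 0.02
R["Ventilated", "Death"] <- 0.05
diag(R) <- -rowSums(R)

S <- mat_exp(R)
rowSums(S)                 # each row sums to (numerically) 1
S["Oxygen", "Discharged"]  # > 0 although R["Oxygen", "Discharged"] == 0
S["Death", "Death"]        # 1: "Death" is absorbing
```

The output illustrates the point above: "Oxygen" to "Discharged" has no direct rate, yet it receives a small positive daily probability via "AA".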

We also find the rate formulation theoretically appealing: the disease progression takes place in continuous time, and the discretization into individual days is an artifact of the way we collected the data, not of the underlying process.

Finally, we put a mixed-effects linear predictor on each of the rates we model, so that the rates are allowed to differ between patients (indexed by $k$) and over (discrete) time (indexed by $t$). For each modeled transition $i \to j$, $R_{k,t;i,j} = \exp(\mu_{k,t;i,j})$ (an element-wise exponential), and $S_{k,t} = \exp(R_{k,t})$ (a matrix exponential). We build the model in the Stan language, using the brms package to express the linear predictors $\mu_{k,t;i,j}$. Baseline patient characteristics enter as time-constant predictors. The treatments administered, the initiation of best supportive care and the markers enter as time-varying predictors.

As an additional structure, we consider rate groups: "Improving" (from a breathing level to a better one, including "Discharged"), "Worsening" (from a breathing level to a worse one, not including "Death"), and "Death" (any transition to the "Death" state). In addition, the complex model has the "To improving" and "To worsening" groups, which correspond to the switches between the improving and worsening variants of each breathing level.

Below we use three types of models: simple with predictors acting on rate groups (i.e. assuming that any predictor has the same multiplicative effect on all rates in a group), simple with predictors acting on rates directly (i.e. a predictor can have a different effect on each rate), and complex with predictors acting on rate groups.
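In equation form, the two ways of attaching a predictor can be sketched as follows. This is our reading of the brms formulas used below, for example `(0 + took_hcq || rate_group)` versus `(0 + took_hcq || .rate_id)`. The notation is assumed: $\alpha_{i,j}$ is the per-rate coefficient (the `.rate_id` term), $x_{k,t,m}$ is the value of predictor $m$ for patient $k$ on day $t$, and $g(i,j)$ is the rate group of the transition $i \to j$. The `||` syntax gives each predictor its own standard deviation $\sigma_m$ and no correlations between predictors:

$$ \text{rate groups:}\quad \mu_{k,t;i,j} = \alpha_{i,j} + \sum_m \beta_{g(i,j),m}\, x_{k,t,m}, \qquad \beta_{g,m} \sim N(0, \sigma_m^2) $$

$$ \text{individual rates:}\quad \mu_{k,t;i,j} = \alpha_{i,j} + \sum_m \beta_{(i,j),m}\, x_{k,t,m}, \qquad \beta_{(i,j),m} \sim N(0, \sigma_m^2) $$

Because $R_{k,t;i,j} = \exp(\mu_{k,t;i,j})$, switching a binary predictor $x_{k,t,m}$ from 0 to 1 has different effects in the two formulations. In the first, it multiplies every rate in group $g$ by the same factor $\exp(\beta_{g,m})$. In the second, it multiplies each rate by its own factor $\exp(\beta_{(i,j),m})$.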

The models are presented in a unified structure, which we explain for the first model and do not repeat afterwards.

# Shift time so that day 0 becomes .time 1 (dropping earlier days), keep only
# series with more than one observed state, and convert predictors to 0/1 dummies
transform_serie_data <- function(raw_serie_data) {
  raw_serie_data %>%
    mutate(.time = .time + 1) %>%
    filter(.time >= 1) %>%
    group_by(.serie) %>%
    filter(sum(!is.na(.observed)) > 1) %>%
    ungroup() %>%
    # Force explicit dummy coding for the models
    mutate(across(tidyselect::starts_with("took_"), as.integer),
           best_supportive_care = as.integer(best_supportive_care),
           is_male = as.integer(sex == "M"))
}

serie_data <- wide %>%
  inner_join(data$patient_data, by = c("hospital_id", "patient_id"), suffix = c("_admission", "")) %>% 
  mutate(.time = day, .serie = patient_id, .observed = breathing_s) %>%
  transform_serie_data()


serie_data_28 <- wide %>%
  right_join(crossing(data$patient_data, day = 0:28), by = c("hospital_id", "patient_id", "day"),  suffix = c("_admission", "")) %>%
  # group_by(patient_id) %>%
  # mutate(
  #   across(
  #     c(starts_with("took_"), all_of("best_supportive_care")),
  #     ~ if_else(is.na(.x) & day > max(day[!is.na(.x)]), .x[which.max(day[!is.na(.x)])], .x)
  #     ),
  # ) %>%
  # ungroup() %>%
  mutate(.time = day, .serie = patient_id, .observed = breathing_s) %>%
  transform_serie_data()

# Extend terminal states and took_XX data for serie_data_28
serie_data_28 <- serie_data_28 %>%
  group_by(.serie) %>%
  mutate(last_day = max(day[!is.na(.observed)]),
         last_state = .observed[day == last_day],
         .observed = if_else(last_state %in% c("Discharged", "Death") & is.na(.observed) & day > last_day,
                             last_state,
                             .observed)) %>%
  mutate(across(c(tidyselect::starts_with("took_"), all_of("best_supportive_care")), ~ if_else(is.na(.x), as.integer(any(.x != 0, na.rm = TRUE)), .x))) %>%
  ungroup()


# Check that the extension did not change any values already present in serie_data
mismatches <- serie_data_28 %>% inner_join(serie_data, by = c("patient_id", "day")) %>%
  select(-starts_with("."), -last_day,  -last_state) %>%
  mutate(across(everything(), as.character)) %>%
  pivot_longer(!any_of(c("patient_id", "day")), names_to = c("col", "source"), names_sep = "\\.") %>%
  group_by(patient_id, day, col) %>%
  summarise(n_vals = length(unique(value)), .groups = "drop") %>%
  filter(n_vals > 1)

if(nrow(mismatches) > 0) {
  print(mismatches)
  stop("Mismatches in serie_data_28")
}


base_observed_state_data <- tibble(id = factor(disease_s_levels, levels = disease_s_levels), is_noisy = FALSE)

simple_hidden_state_data <- tibble(id = factor(disease_s_levels, levels = disease_s_levels)) %>% mutate( corresponding_obs = id)

simple_initial_states <- serie_data %>% filter(.time == 1) %>% arrange(.serie) %>% pull(.observed)

default_control <- list(adapt_delta = 0.9)

Treatments - simple, effects on rate groups

simple_rate_data <- rbind(
  tibble(.from = c("Oxygen", "Ventilated"), .to = "Death", rate_group = "death"),
  tibble(.from = "AA", .to = "Discharged", rate_group = "improve_one"),
  tibble(.from = breathing_s_levels[2 : length(breathing_s_levels)], .to = breathing_s_levels[1 : (length(breathing_s_levels) - 1)], rate_group = "improve_one"),
  tibble(.from = breathing_s_levels[1 : (length(breathing_s_levels) - 1)], .to = breathing_s_levels[2 : length(breathing_s_levels)], rate_group = "worsen_one")
) %>%
  mutate(.rate_id = factor(paste0(.from, "_", .to), levels = paste0(.from, "_", .to)),
         .from = factor(.from, levels = disease_s_levels),
         .to = factor(.to, levels = disease_s_levels),
          rate_group = fct_relevel(rate_group, "improve_one")
       )


The rates included in the simple model are shown in Table \@ref(tab:ratessimple).

simple_rate_data  %>% knitr::kable(booktabs = TRUE, caption="Rates used in the simple model.")

Only treatments

trt_group_m1_hidden_state_data <- simple_hidden_state_data

trt_group_m1_observed_state_data <- base_observed_state_data


trt_group_m1_prior <- brms::set_prior("normal(-2, 5)", "b") +
  brms::set_prior("normal(0, 2)", "sd")

trt_group_m1_formula <- ~ 0 + .rate_id +  (0 + best_supportive_care + took_hcq + took_az + took_favipiravir + took_convalescent_plasma || rate_group)

For this model, the $\log(R_{i,j})$ elements are modeled via the brms formula:

trt_group_m1_formula

Our code is set up so that the formula can combine predictors drawn from the description of rates (see Table \@ref(tab:ratessimple)) with predictors taken from the patients' data (both the fixed baseline characteristics and the time-varying values). There is a separate coefficient for each rate, and the other predictors act uniformly on each rate group. The took_XXX variables are binary (0/1) predictors that are set to 1 for all days starting with the first dose of the treatment, i.e. we assume that the treatment irreversibly alters the overall disease progression.
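Here is a minimal sketch of how such a "switched on from the first dose" indicator can be derived from a daily record of doses. The input `dose_given` and the function `took_indicator` are hypothetical. The actual took_XXX columns are prepared in data-processing code that is not shown here. The sketch assumes the dose record has no missing days.

```r
# Sketch: a took_XXX-style indicator that turns on at the first dose and
# stays on. `dose_given` is a hypothetical daily 0/1 record without NAs.
took_indicator <- function(dose_given) {
  # cumsum() counts doses up to each day; any dose so far means "treated"
  as.integer(cumsum(dose_given) > 0)
}

took_indicator(c(0, 0, 1, 0, 0, 1, 0))  # 0 0 1 1 1 1 1
```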

This is the summary of the fitted coefficients. Q2.5 and Q97.5 are the boundaries of the central 95% credible interval.

trt_group_m1 <- brmshmmdata(trt_group_m1_formula, serie_data, 
                  simple_rate_data, simple_hidden_state_data, simple_initial_states, prior = trt_group_m1_prior,
                  observed_state_data = base_observed_state_data)

trt_group_m1_fit <- brmhmm(trt_group_m1, cache_file = paste0(cache_dir,"/trt_group_m1.rds"), control = default_control, init = 0.1, iter = 1500)

print_hmm_fit_summary(trt_group_m1_fit)

Note that the coefficients are hard to interpret directly. For example, initiating best supportive care (best_supportive_care) increases the rate of "Death", but it might also decrease the rate of "worsening" transitions. How should those be combined? To make the results easy to interpret, we use the model to build counterfactual predictions for each patient and treatment. Specifically, we ask what (according to the model) the probability of the patient having died by day 28 would be if they had not received the treatment at all, and what it would be if they had received the treatment immediately upon admission. Those two probabilities are then used to compute the log odds-ratio, which is the value we report. Similarly, we compute the log odds-ratio for still being hospitalized versus being discharged. In both cases, a log odds-ratio > 0 means a worse outcome for the patient in the treatment group (a higher chance of death or a higher chance of still being hospitalized).
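The log odds-ratio step can be sketched as follows, given counterfactual day-28 death probabilities for every posterior draw and patient. This is not the project's `evaluate_all_treatment_hypotheses()`. The input matrices and helper names are hypothetical. We also assume, without confirmation from the text, that per-patient probabilities are averaged over patients within each draw before the odds ratio is formed.

```r
# Sketch: log odds-ratio of death by day 28 from counterfactual predictions.
# Assumption: probabilities are averaged over patients within each draw.
# p_treated, p_untreated: hypothetical (draws x patients) matrices.
log_odds <- function(p) log(p / (1 - p))

counterfactual_log_or <- function(p_treated, p_untreated) {
  # One log OR per posterior draw; > 0 means more deaths with treatment
  log_odds(rowMeans(p_treated)) - log_odds(rowMeans(p_untreated))
}

# Toy input: 2 posterior draws x 3 patients (made-up numbers)
p_untreated <- matrix(c(0.10, 0.20, 0.30,
                        0.15, 0.25, 0.35), nrow = 2, byrow = TRUE)
p_treated <- matrix(c(0.30, 0.30, 0.30,
                      0.25, 0.25, 0.25), nrow = 2, byrow = TRUE)
log_or_draws <- counterfactual_log_or(p_treated, p_untreated)
log_or_draws  # draw 1: log(12/7), about 0.54; draw 2: 0
```

With many draws, the reported interval would be the 2.5% and 97.5% quantiles of `log_or_draws`.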

Additionally, we look at the (counterfactual) duration of hospitalization for patients discharged by day 28 in both groups. To put it on a comparable scale, we compute the mean of the log of the per-patient ratios of those durations, so that an estimate > 0 implies a worse outcome (longer hospitalization) in the treatment group. This value is reported only in the supplement, not in the main manuscript.
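Here is a minimal sketch of this duration summary for a single posterior draw. The inputs are hypothetical per-patient durations in days, with `NA` when the patient is not discharged by day 28. We read "discharged by day 28 in both groups" as "discharged under both counterfactual scenarios". That reading is an assumption.

```r
# Sketch: mean log ratio of counterfactual hospitalization durations.
# Assumption: only patients discharged by day 28 in both scenarios count.
mean_log_duration_ratio <- function(days_treated, days_untreated) {
  both_discharged <- !is.na(days_treated) & !is.na(days_untreated)
  # > 0 means longer hospitalization with treatment
  mean(log(days_treated[both_discharged] / days_untreated[both_discharged]))
}

# Third patient is excluded (not discharged by day 28 when treated)
mean_log_duration_ratio(c(10, 8, NA, 12), c(5, 8, 9, 12))  # log(2) / 3
```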

For this model this results in the following estimates:


hmm_hypothesis_res_list[["trt_group_m1"]] <- evaluate_all_treatment_hypotheses(trt_group_m1_fit, model_subgroup = "simple_group", adjusted ="all treatments, supportive", model_check = "OK", cl = cl, serie_data_28 = serie_data_28)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["trt_group_m1"]])
if(do_pp_checks) {
  do_common_pp_checks(trt_group_m1_fit)
}
rm(trt_group_m1_fit)

Treatments + age + sex

In this model the formula is expanded to include age and sex:

trt_group_m2_formula <- update.formula(trt_group_m1_formula,  ~ . +  (0 + age_norm + is_male || rate_group))

trt_group_m2_formula

This is the summary of the model parameters:

trt_group_m2 <- brmshmmdata(trt_group_m2_formula, serie_data, 
                  simple_rate_data, trt_group_m1_hidden_state_data, simple_initial_states, prior = trt_group_m1_prior,
                  observed_state_data = trt_group_m1_observed_state_data)

trt_group_m2_fit <- brmhmm(trt_group_m2, cache_file = paste0(cache_dir,"/trt_group_m2.rds"), init = 0.1, control = list(adapt_delta = 0.95), iter = 1200)

#parnames(trt_group_m2_fit$brmsfit)
print_hmm_fit_summary(trt_group_m2_fit)

And here are the estimates of the log ORs:

hmm_hypothesis_res_list[["trt_group_m2"]] <- evaluate_all_treatment_hypotheses(trt_group_m2_fit, model_subgroup = "simple_group", adjusted ="all treatments, supportive, age, sex", model_check = "OK", cl = cl, serie_data_28 = serie_data_28)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["trt_group_m2"]])
if(do_pp_checks) {
  do_common_pp_checks(trt_group_m2_fit)
}
rm(trt_group_m2_fit)

Treatments + age + sex + hospital

Here the model is further expanded to allow all the rates to vary between hospitals.
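Here is a sketch of what the added term `(0 + .rate_id | hospital_id)` means in equation form. The notation is assumed: $\mu^{\text{prev}}_{k,t;i,j}$ is the linear predictor of the previous model, $h(k)$ is the hospital of patient $k$, and $\gamma_h$ is the vector of hospital-level deviations of the log rates, with one entry per modeled rate. Because the formula uses a single `|`, these entries may be correlated. The standard deviations on the diagonal of $\Sigma$ are the "sd" parameters that receive the normal(0, 2) prior:

$$ \mu_{k,t;i,j} = \mu^{\text{prev}}_{k,t;i,j} + \gamma_{h(k);i,j}, \qquad \gamma_{h} \sim \mathrm{MVN}(0, \Sigma) $$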

trt_group_m3_formula <- update.formula(trt_group_m2_formula,  ~  . + (0 + .rate_id | hospital_id))
trt_group_m3_formula

Summary of the coefficients:

trt_group_m3 <- brmshmmdata(trt_group_m3_formula, serie_data, 
                  simple_rate_data, trt_group_m1_hidden_state_data, simple_initial_states, prior = trt_group_m1_prior,
                  observed_state_data = trt_group_m1_observed_state_data)

trt_group_m3_fit <- brmhmm(trt_group_m3, cache_file = paste0(cache_dir,"/trt_group_m3.rds"), init = 0.1, control = list(adapt_delta = 0.95), iter = 1200)

print_hmm_fit_summary(trt_group_m3_fit)

We see that the data are consistent with quite high between-site differences (measured in standard deviations of the varying intercepts).

And the resulting log ORs are:

hmm_hypothesis_res_list[["trt_group_m3"]] <-  evaluate_all_treatment_hypotheses(trt_group_m3_fit, model_subgroup = "simple_group", adjusted ="all treatments, supportive, age, sex, hospital", model_check = "OK", cl = cl, serie_data_28 = serie_data_28)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["trt_group_m3"]])
if(do_pp_checks) {
  do_common_pp_checks(trt_group_m3_fit)
}
rm(trt_group_m3_fit)

Treatments + age + sex + hospital, first wave only

This is the same model as above, but using only the patients from the first wave. Here is the summary of the coefficients.

serie_data_fw <- serie_data %>% filter(first_wave)
simple_initial_states_fw <- serie_data_fw %>% filter(.time == 1) %>% arrange(.serie) %>% pull(.observed) 


trt_group_m3_fw <- brmshmmdata(trt_group_m3_formula, serie_data_fw, 
                  simple_rate_data, trt_group_m1_hidden_state_data, simple_initial_states_fw, prior = trt_group_m1_prior,
                  observed_state_data = trt_group_m1_observed_state_data)

trt_group_m3_fw_fit <- brmhmm(trt_group_m3_fw, cache_file = paste0(cache_dir,"/trt_group_m3_fw.rds"), init = 0.1, control = default_control, iter = 1200)

print_hmm_fit_summary(trt_group_m3_fw_fit)

And the resulting log ORs/log duration ratios:

serie_data_28_fw = serie_data_28 %>% filter(first_wave)

hmm_hypothesis_res_list[["trt_group_m3_fw"]] <-  evaluate_all_treatment_hypotheses(trt_group_m3_fw_fit, model_subgroup = "simple_group", adjusted ="all treatments, supportive, age, sex, hospital, first_wave", model_check = "OK", cl = cl, serie_data_28 = serie_data_28_fw)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["trt_group_m3_fw"]])
if(do_pp_checks) {
  do_common_pp_checks(trt_group_m3_fw_fit)
}
rm(trt_group_m3_fw_fit)

Treatments + age + sex + hospital + comorbidities

For this model, the sum of comorbidities is added as a continuous predictor, resulting in the formula:

trt_group_m4_formula <-  update.formula(trt_group_m3_formula,  ~ . + (comorbidities_sum || rate_group)) 
trt_group_m4_formula

Where comorbidities_sum is a comorbidity score as described in Section \@ref(comorbiditiessum).

Below is the summary of the coefficients:

trt_group_m4 <- brmshmmdata(trt_group_m4_formula, serie_data, 
                  simple_rate_data, trt_group_m1_hidden_state_data, simple_initial_states, prior = trt_group_m1_prior,
                  observed_state_data = trt_group_m1_observed_state_data)


trt_group_m4_fit <- brmhmm(trt_group_m4, cache_file = paste0(cache_dir,"/trt_group_m4.rds"), init = 0.1, control = list(adapt_delta = 0.95), iter = 1200)

print_hmm_fit_summary(trt_group_m4_fit)

The sampler warns of divergent transitions, indicating that the posterior might not have been fully explored. This is likely because the data are insufficient to estimate all the parameters. Here and in the following, we mark models that had such problems as "Suspicious" and do not report their results in the main text of the manuscript.

The resulting log ORs (which cannot be fully trusted):

hmm_hypothesis_res_list[["trt_group_m4"]] <-  evaluate_all_treatment_hypotheses(trt_group_m4_fit, model_subgroup = "simple_group", adjusted ="all treatments, supportive, age, sex, hospital, comorbidities", model_check = "Suspicious", cl = cl, serie_data_28 = serie_data_28)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["trt_group_m4"]])
if(do_pp_checks) {
  do_common_pp_checks(trt_group_m4_fit)
}
rm(trt_group_m4_fit)

Treatment + age + sex + hospital + comorbidities, first wave

Once again, we run the same model, but for the first-wave data only. As in the previous case, this results in divergent transitions, and the model should not be fully trusted.

trt_group_m4_fw <- brmshmmdata(trt_group_m4_formula, serie_data_fw, 
                  simple_rate_data, trt_group_m1_hidden_state_data, simple_initial_states_fw, prior = trt_group_m1_prior,
                  observed_state_data = trt_group_m1_observed_state_data)


trt_group_m4_fw_fit <- brmhmm(trt_group_m4_fw, cache_file = paste0(cache_dir,"/trt_group_m4_fw.rds"), init = 0.1, control = default_control, iter = 1200)

print_hmm_fit_summary(trt_group_m4_fw_fit)

And here are the resulting log ORs:

hmm_hypothesis_res_list[["trt_group_m4_fw"]] <-  evaluate_all_treatment_hypotheses(trt_group_m4_fw_fit, model_subgroup = "simple_group", adjusted ="all treatments, supportive, age, sex, hospital, comorbidities, first_wave", model_check = "Suspicious", cl = cl, serie_data_28 = serie_data_28_fw)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["trt_group_m4_fw"]])
if(do_pp_checks) {
  do_common_pp_checks(trt_group_m4_fw_fit)
}
rm(trt_group_m4_fw_fit)
invisible(parallel::clusterEvalQ(cl, gc()))

Treatments - effect on individual rates

In this set of models the effects (especially the treatments) act directly on the individual rates.

Only treatments

For treatments only, this boils down to the following brms formula:

trt_rate_m1_hidden_state_data <- simple_hidden_state_data

trt_rate_m1_observed_state_data <- base_observed_state_data

trt_rate_m1_rate_data <- simple_rate_data

trt_rate_m1_initial_states <- simple_initial_states

trt_rate_m1_prior = trt_group_m1_prior

trt_rate_m1_formula <- ~ 0 + .rate_id +  (0 + best_supportive_care + took_hcq + took_az + took_favipiravir + took_convalescent_plasma || .rate_id)

trt_rate_m1_formula

Resulting in the following fitted coefficients:

trt_rate_m1 <- brmshmmdata(trt_rate_m1_formula, serie_data, 
                  trt_rate_m1_rate_data, simple_hidden_state_data, simple_initial_states, prior = trt_rate_m1_prior,
                  observed_state_data = base_observed_state_data)

trt_rate_m1_fit <- brmhmm(trt_rate_m1, cache_file = paste0(cache_dir,"/trt_rate_m1.rds"), control = default_control, init = 0.1)

print_hmm_fit_summary(trt_rate_m1_fit)
hmm_hypothesis_res_list[["trt_rate_m1"]] <-  evaluate_all_treatment_hypotheses(trt_rate_m1_fit, model_subgroup = "simple_rate", adjusted ="all treatments, supportive", model_check = "OK", cl = cl, serie_data_28 = serie_data_28)
print_hmm_hypothesis_res(hmm_hypothesis_res_list[["trt_rate_m1"]])
if(do_pp_checks) {
  do_common_pp_checks(trt_rate_m1_fit)
}
rm(trt_rate_m1_fit)

Treatments + age + sex

trt_rate_m2_formula <- update.formula(trt_rate_m1_formula,  ~ . +  (0 + age_norm + is_male || .rate_id))
trt_rate_m2 <- brmshmmdata(trt_rate_m2_formula, serie_data, 
                  trt_rate_m1_rate_data, trt_rate_m1_hidden_state_data, trt_rate_m1_initial_states, prior = trt_rate_m1_prior,
                  observed_state_data = trt_rate_m1_observed_state_data)

trt_rate_m2_fit <- brmhmm(trt_rate_m2, cache_file = paste0(cache_dir,"/trt_rate_m2.rds"), init = 0.1, control = default_control)

print_hmm_fit_summary(trt_rate_m2_fit)

And the resulting log ORs for the outcomes:

hmm_hypothesis_res_list[["trt_rate_m2"]] <-  evaluate_all_treatment_hypotheses(trt_rate_m2_fit, model_subgroup = "simple_rate", adjusted ="all treatments, supportive, age, sex", model_check = "OK", cl = cl, serie_data_28 = serie_data_28)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["trt_rate_m2"]])
if(do_pp_checks) {
  do_common_pp_checks(trt_rate_m2_fit)
}
rm(trt_rate_m2_fit)

Treatments + age + sex + hospital

We allow the rates to differ by hospital site with this formula:

trt_rate_m3_formula <- update.formula(trt_rate_m2_formula,  ~  . + (0 + .rate_id | hospital_id))
trt_rate_m3_formula

Those are the resulting fitted coefficients:

trt_rate_m3 <- brmshmmdata(trt_rate_m3_formula, serie_data, 
                  trt_rate_m1_rate_data, trt_rate_m1_hidden_state_data, trt_rate_m1_initial_states, prior = trt_rate_m1_prior,
                  observed_state_data = trt_rate_m1_observed_state_data)

trt_rate_m3_fit <- brmhmm(trt_rate_m3, cache_file = paste0(cache_dir,"/trt_rate_m3.rds"), init = 0.1, control = list(adapt_delta = 0.95), iter = 1500)

print_hmm_fit_summary(trt_rate_m3_fit)

And the resulting estimates:

hmm_hypothesis_res_list[["trt_rate_m3"]] <-  evaluate_all_treatment_hypotheses(trt_rate_m3_fit, model_subgroup = "simple_rate", adjusted ="all treatments, supportive, age, sex, hospital", model_check = "OK", cl = cl, serie_data_28 = serie_data_28)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["trt_rate_m3"]])
if(do_pp_checks) {
  do_common_pp_checks(trt_rate_m3_fit)
}
rm(trt_rate_m3_fit)

Treatments + age + sex + hospital, first wave

This is the same as the previous model, but limited to first-wave patients. The fitted coefficients:

trt_rate_m3_fw <- brmshmmdata(trt_rate_m3_formula, serie_data_fw, 
                  trt_rate_m1_rate_data, trt_rate_m1_hidden_state_data, simple_initial_states_fw, prior = trt_rate_m1_prior,
                  observed_state_data = trt_rate_m1_observed_state_data)

trt_rate_m3_fw_fit <- brmhmm(trt_rate_m3_fw, cache_file = paste0(cache_dir,"/trt_rate_m3_fw.rds"), init = 0.1, control = list(adapt_delta = 0.95), iter = 1500)

print_hmm_fit_summary(trt_rate_m3_fw_fit)

And the log(OR) estimates:

hmm_hypothesis_res_list[["trt_rate_m3_fw"]] <-  evaluate_all_treatment_hypotheses(trt_rate_m3_fw_fit, model_subgroup = "simple_rate", adjusted ="all treatments, supportive, age, sex, hospital, first_wave", model_check = "OK", cl = cl, serie_data_28 = serie_data_28_fw)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["trt_rate_m3_fw"]])
if(do_pp_checks) {
  do_common_pp_checks(trt_rate_m3_fw_fit)
}
rm(trt_rate_m3_fw_fit)

Treatments + age + sex + hospital + comorbidities

Adding comorbidities gives the complete formula:

trt_rate_m4_formula <-  update.formula(trt_rate_m3_formula, ~ . + (comorbidities_sum || .rate_id))
trt_rate_m4_formula

Where comorbidities_sum is a comorbidity score as described in Section \@ref(comorbiditiessum).

Coefficient estimates - note the divergent transitions, indicating the results are not completely trustworthy:

trt_rate_m4 <- brmshmmdata(trt_rate_m4_formula, serie_data, 
                  trt_rate_m1_rate_data, trt_rate_m1_hidden_state_data, trt_rate_m1_initial_states, prior = trt_rate_m1_prior,
                  observed_state_data = trt_rate_m1_observed_state_data)

trt_rate_m4_fit <- brmhmm(trt_rate_m4, cache_file = paste0(cache_dir,"/trt_rate_m4.rds"), init = 0.1, control = default_control, iter = 1500)

print_hmm_fit_summary(trt_rate_m4_fit)

Log odds-ratio estimates:

hmm_hypothesis_res_list[["trt_rate_m4"]] <- evaluate_all_treatment_hypotheses(trt_rate_m4_fit, model_subgroup = "simple_rate", adjusted ="all treatments, supportive, age, sex, hospital, comorbidities", model_check = "Suspicious", cl = cl, serie_data_28 = serie_data_28)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["trt_rate_m4"]])
if(do_pp_checks) {
  do_common_pp_checks(trt_rate_m4_fit)
}
rm(trt_rate_m4_fit)
invisible(parallel::clusterEvalQ(cl, gc()))

Markers

In the marker models, we use the simple model and investigate effects on rate groups. We include only patients who had the marker measured at least once. We use the peak value of the marker observed so far (i.e. the running maximum up to the given day) as a predictor in the model and investigate its estimated coefficient. Unfortunately, none of the results are fully trustworthy, as the sampler encountered divergent transitions.
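The "peak so far" predictor can be sketched for a single patient as a running maximum that skips missing measurements. This is a hypothetical stand-in, not the project's `compute_marker_peak()`. We assume that before the first measurement the peak equals `initial_value` (0 for D-dimer and 1 for IL-6 below, which the code then log-transforms). In a grouped pipeline it would be applied within each `.serie`.

```r
# Sketch: running peak of a marker with missing measurements (one patient).
# Hypothetical helper; before the first measurement we use `initial_value`.
running_peak <- function(values, initial_value = 0) {
  filled <- ifelse(is.na(values), -Inf, values)  # NA never raises the peak
  peak <- cummax(filled)
  peak[peak == -Inf] <- initial_value            # days before first measurement
  peak
}

running_peak(c(NA, 3, NA, 1, 5, NA), initial_value = 0)  # 0 3 3 3 5 5
```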

D-dimer + hospital

The formula is:

d_dimer_serie_data <- serie_data %>% 
  group_by(.serie) %>%
  filter(any(!is.na(d_dimer))) %>%
  ungroup() %>%
  compute_marker_peak(column_name = "d_dimer", new_column_name = "peak_d_dimer", initial_value = 0) %>%
  mutate(log_peak_d_dimer = log(peak_d_dimer + 1) - 4)

d_dimer_initial_states <- d_dimer_serie_data %>% filter(.time == 1) %>% arrange(.serie) %>% pull(.observed)


d_dimer_m1_formula <-  ~ 0 + .rate_id +  (0 + best_supportive_care || rate_group)  + (0 + .rate_id | hospital_id) + (log_peak_d_dimer || rate_group) 


d_dimer_m1_formula

And here are the coefficient estimates - note the divergent transitions, indicating the results are not completely trustworthy:

d_dimer_m1 <- brmshmmdata(d_dimer_m1_formula, d_dimer_serie_data, 
                  simple_rate_data, trt_group_m1_hidden_state_data, d_dimer_initial_states, prior = trt_group_m1_prior,
                  observed_state_data = trt_group_m1_observed_state_data)

d_dimer_m1_fit <- brmhmm(d_dimer_m1, cache_file = paste0(cache_dir,"/d_dimer_m1.rds"), init = 0.1, control = default_control, iter = 1500)

print_hmm_fit_summary(d_dimer_m1_fit)
draws <- tidybayes::tidy_draws(d_dimer_m1_fit$brmsfit)
hmm_hypothesis_res_list[["d_dimer_m1_death"]] <-
  bayesian_hypothesis_res_from_draws(
    draws = draws$`r_rate_group[death,log_peak_d_dimer]`,
    model = "HMMsimple_group",
    estimand = "log(HR)",
    hypothesis = hypotheses$d_dimer_death,
    adjusted = "supportive, hospital",
    model_check = "Suspicious"
  )
if(do_pp_checks) {
  do_common_pp_checks(d_dimer_m1_fit)
}
rm(d_dimer_m1_fit)

D-dimer + age + sex + hospital

Adding age and sex to the model gives the formula:

d_dimer_m2_formula <-  update.formula(d_dimer_m1_formula, ~ . + (0 + age_norm + is_male || rate_group))
d_dimer_m2_formula


And here are the coefficient estimates - note the divergent transitions, indicating the results are not completely trustworthy:

d_dimer_m2 <- brmshmmdata(d_dimer_m2_formula, d_dimer_serie_data, 
                  simple_rate_data, trt_group_m1_hidden_state_data, d_dimer_initial_states, prior = trt_group_m1_prior,
                  observed_state_data = trt_group_m1_observed_state_data)

d_dimer_m2_fit <- brmhmm(d_dimer_m2, cache_file = paste0(cache_dir,"/d_dimer_m2.rds"), init = 0.1, control = default_control, iter = 1500)

print_hmm_fit_summary(d_dimer_m2_fit)
draws <- tidybayes::tidy_draws(d_dimer_m2_fit$brmsfit)
hmm_hypothesis_res_list[["d_dimer_m2_death"]] <-
  bayesian_hypothesis_res_from_draws(
    draws = draws$`r_rate_group[death,log_peak_d_dimer]`,
    model = "HMMsimple_group",
    estimand = "log(HR)",
    hypothesis = hypotheses$d_dimer_death,
    adjusted = "supportive, age, sex, hospital",
    model_check = "Suspicious"
  )
rm(d_dimer_m2_fit)

IL-6 + hospital

We further look into interleukin 6 (IL-6) using the formula:

IL_6_serie_data <- serie_data %>% 
  group_by(.serie) %>%
  filter(any(!is.na(IL_6))) %>%
  ungroup() %>%
  compute_marker_peak(column_name = "IL_6", new_column_name = "peak_IL_6", initial_value = 1) %>%
  mutate(log_peak_IL_6 = log(peak_IL_6))

IL_6_initial_states <- IL_6_serie_data %>% filter(.time == 1) %>% arrange(.serie) %>% pull(.observed)


IL_6_m1_formula <- ~ 0 + .rate_id +  (0 + best_supportive_care || rate_group)  + (0 + .rate_id | hospital_id) + (log_peak_IL_6 || rate_group) 
IL_6_m1_formula

Giving us the following coefficients - note the divergent transitions, indicating the results are not completely trustworthy:

IL_6_m1 <- brmshmmdata(IL_6_m1_formula, IL_6_serie_data, 
                  simple_rate_data, trt_group_m1_hidden_state_data, IL_6_initial_states, prior = trt_group_m1_prior,
                  observed_state_data = trt_group_m1_observed_state_data)

IL_6_m1_fit <- brmhmm(IL_6_m1, cache_file = paste0(cache_dir,"/IL_6_m1.rds"), init = 0.1, control = default_control, iter = 1500)

print_hmm_fit_summary(IL_6_m1_fit)
draws <- tidybayes::tidy_draws(IL_6_m1_fit$brmsfit)
hmm_hypothesis_res_list[["IL_6_m1_death"]] <-
  bayesian_hypothesis_res_from_draws(
    draws = draws$`r_rate_group[death,log_peak_IL_6]`,
    model = "HMMsimple_group",
    estimand = "log(HR)",
    hypothesis = hypotheses$IL_6_death,
    adjusted = "supportive, hospital",
    model_check = "Suspicious"
  )
if(do_pp_checks) {
  do_common_pp_checks(IL_6_m1_fit)
}
rm(IL_6_m1_fit)

IL-6 + age + sex + hospital

And finally, adding age and sex gives the formula:

IL_6_m2_formula <-  update.formula(IL_6_m1_formula, ~ . + (0 + age_norm + is_male || rate_group))
IL_6_m2_formula


This gives the following estimates - note the divergent transitions, indicating the results are not completely trustworthy:

IL_6_m2 <- brmshmmdata(IL_6_m2_formula, IL_6_serie_data, 
                  simple_rate_data, trt_group_m1_hidden_state_data, IL_6_initial_states, prior = trt_group_m1_prior,
                  observed_state_data = trt_group_m1_observed_state_data)

IL_6_m2_fit <- brmhmm(IL_6_m2, cache_file = paste0(cache_dir,"/IL_6_m2.rds"), init = 0.1, control = default_control, iter = 1500)

print_hmm_fit_summary(IL_6_m2_fit)
draws <- tidybayes::tidy_draws(IL_6_m2_fit$brmsfit)
hmm_hypothesis_res_list[["IL_6_m2_death"]] <-
  bayesian_hypothesis_res_from_draws(
    draws = draws$`r_rate_group[death,log_peak_IL_6]`,
    model = "HMMsimple_group",
    estimand = "log(HR)",
    hypothesis = hypotheses$IL_6_death,
    adjusted = "supportive, age, sex, hospital",
    model_check = "Suspicious"
  )
rm(IL_6_m2_fit)
invisible(parallel::clusterEvalQ(cl, gc()))

Complex: Treatments - effects on rate groups

complex_rate_data <- rbind(
  tibble(.from = paste0(c("Oxygen", "Ventilated"), "_worsening"), .to = "Death", rate_group = "death"),
  tibble(.from = "AA_improving", 
         .to = "Discharged", 
         rate_group = "improve_one"),
  tibble(.from = paste0(breathing_s_levels[2 : length(breathing_s_levels)], "_improving"), 
         .to = paste0(breathing_s_levels[1 : (length(breathing_s_levels) - 1)] , "_improving"), 
         rate_group = "improve_one"),
  tibble(.from = paste0(breathing_s_levels[1 : (length(breathing_s_levels) - 1)], "_worsening"),
         .to = paste0(breathing_s_levels[2 : length(breathing_s_levels)], "_worsening"),
         rate_group = "worsen_one"),
  tibble(.from = paste0(breathing_s_levels, "_worsening"),
         .to = paste0(breathing_s_levels, "_improving"),
         rate_group = "to_improving"),
  tibble(.from = paste0(breathing_s_levels, "_improving"),
         .to = paste0(breathing_s_levels, "_worsening"),
         rate_group = "to_worsening")
) %>%
  mutate(.rate_id = factor(paste0(.from, "_", .to), levels = paste0(.from, "_", .to)),
         rate_group = fct_relevel(rate_group, "improve_one"))

complex_hidden_state_data <- rbind(
  tibble(base = breathing_s_levels) %>% mutate(corresponding_obs = base) %>% 
    crossing(tibble(state_group = c("_improving", "_worsening"))) %>% 
    mutate(id = factor(paste0(base, state_group))) %>%
    select(-base, -state_group),
  tibble(id = c("Death","Discharged")) %>% mutate(corresponding_obs = id)
)


complex_initial_states <- serie_data %>% filter(.time == 1) %>% arrange(.serie) %>% pull(.observed) %>% paste0(., "_worsening")
complex_rate_data %>% knitr::kable(booktabs = TRUE, caption="Rates used in the complex model.") 

Here we use the complex model and let treatments affect the rate groups. The models use the same formulae as the simple models; only the rate (and state) definitions change. The rates used are shown in Table \@ref(tab:ratescomplex).
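To make the structure of the complex model concrete, the following base-R sketch re-creates the rate table and arranges it as the generator matrix of a continuous-time Markov chain. It assumes breathing_s_levels is c("AA", "Oxygen", "Ventilated"), the three levels referenced in this section (the real vector may differ), and uses a hypothetical placeholder rate of 0.1 per time unit for every transition. How brmshmm itself turns rates into per-step transition probabilities is not shown in this document. The matrix exponential used here, implemented by scaling and squaring with a truncated Taylor series because base R has no expm(), is one standard way to do it and not necessarily the package's.

```r
# Assumed breathing support levels, from least to most support
bs_levels <- c("AA", "Oxygen", "Ventilated")
n <- length(bs_levels)

# Base-R version of complex_rate_data (same transitions, no tidyverse)
rates <- rbind(
  data.frame(from = paste0(c("Oxygen", "Ventilated"), "_worsening"),
             to = "Death", rate_group = "death"),
  data.frame(from = "AA_improving", to = "Discharged", rate_group = "improve_one"),
  data.frame(from = paste0(bs_levels[2:n], "_improving"),
             to = paste0(bs_levels[1:(n - 1)], "_improving"), rate_group = "improve_one"),
  data.frame(from = paste0(bs_levels[1:(n - 1)], "_worsening"),
             to = paste0(bs_levels[2:n], "_worsening"), rate_group = "worsen_one"),
  data.frame(from = paste0(bs_levels, "_worsening"),
             to = paste0(bs_levels, "_improving"), rate_group = "to_improving"),
  data.frame(from = paste0(bs_levels, "_improving"),
             to = paste0(bs_levels, "_worsening"), rate_group = "to_worsening")
)

# Hidden states: each support level split into improving/worsening,
# plus the absorbing states Death and Discharged
hidden_states <- c(paste0(rep(bs_levels, 2), rep(c("_improving", "_worsening"), each = n)),
                   "Death", "Discharged")
# Observed state that each hidden state corresponds to
corresponding_obs <- sub("_(improving|worsening)$", "", hidden_states)

# Generator matrix: off-diagonal entries are transition rates,
# the diagonal makes every row sum to zero
Q <- matrix(0, nrow = length(hidden_states), ncol = length(hidden_states),
            dimnames = list(hidden_states, hidden_states))
Q[cbind(as.character(rates$from), as.character(rates$to))] <- 0.1  # placeholder rate
diag(Q) <- -rowSums(Q)

# Matrix exponential by scaling and squaring with a truncated Taylor series
mat_exp <- function(A, terms = 20) {
  # Choose s so that the infinity norm of A / 2^s is at most 1
  s <- max(0, ceiling(log2(max(abs(A)) * nrow(A))))
  B <- A / 2^s
  result <- diag(nrow(A))
  term <- diag(nrow(A))
  for (k in seq_len(terms)) {
    term <- term %*% B / k
    result <- result + term
  }
  for (i in seq_len(s)) {
    result <- result %*% result
  }
  dimnames(result) <- dimnames(A)
  result
}

P1 <- mat_exp(Q)        # transition probabilities over one time unit
P28 <- mat_exp(28 * Q)  # and over 28 time units
# Probability of being dead 28 time units after starting in Oxygen_worsening
P28["Oxygen_worsening", "Death"]
```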

Only treatments

Using treatments only with the formula:

trt_group_m1_formula

Fitted coefficients:

c_trt_m1 <- brmshmmdata(trt_group_m1_formula, serie_data, 
                  complex_rate_data, complex_hidden_state_data, complex_initial_states, prior = trt_group_m1_prior,
                  observed_state_data = base_observed_state_data)

c_trt_m1_fit <- brmhmm(c_trt_m1, cache_file = paste0(cache_dir,"/c_trt_m1.rds"), control = default_control, init = 0.1, iter = 1500)

print_hmm_fit_summary(c_trt_m1_fit)

Estimates:

hmm_hypothesis_res_list[["c_trt_rate_m1"]] <- evaluate_all_treatment_hypotheses(c_trt_m1_fit, model_subgroup = "complex_group", adjusted ="all treatments, supportive", model_check = "OK", cl = cl, serie_data_28 = serie_data_28)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["c_trt_rate_m1"]])
if(do_pp_checks) {
  do_common_pp_checks(c_trt_m1_fit)
}
rm(c_trt_m1_fit)

Treatments + age + sex

Adding age and sex with formula:

trt_group_m2_formula

Fitted coefficients:

c_trt_m2 <- brmshmmdata(trt_group_m2_formula, serie_data, 
                  complex_rate_data, complex_hidden_state_data, complex_initial_states, prior = trt_group_m1_prior,
                  observed_state_data = base_observed_state_data)


c_trt_m2_fit <- brmhmm(c_trt_m2, cache_file = paste0(cache_dir,"/c_trt_m2.rds"), init = 0.1, control = default_control, iter = 1200)

#parnames(c_trt_m2_fit$brmsfit)
print_hmm_fit_summary(c_trt_m2_fit)

Estimates:

hmm_hypothesis_res_list[["c_trt_rate_m2"]] <- evaluate_all_treatment_hypotheses(c_trt_m2_fit, model_subgroup = "complex_group", adjusted ="all treatments, supportive, age, sex", model_check = "OK", cl = cl, serie_data_28 = serie_data_28)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["c_trt_rate_m2"]])
if(do_pp_checks) {
  do_common_pp_checks(c_trt_m2_fit)
}
rm(c_trt_m2_fit)

Treatments + age + sex + hospital

Including hospital as well with the formula:

trt_group_m3_formula

Fitted coefficients:

c_trt_m3 <- brmshmmdata(trt_group_m3_formula, serie_data, 
                  complex_rate_data, complex_hidden_state_data, complex_initial_states, prior = trt_group_m1_prior,
                  observed_state_data = base_observed_state_data)

c_trt_m3_fit <- brmhmm(c_trt_m3, cache_file = paste0(cache_dir,"/c_trt_m3.rds"), init = 0.1, control = default_control, iter = 1200)

print_hmm_fit_summary(c_trt_m3_fit)

Estimates:

hmm_hypothesis_res_list[["c_trt_rate_m3"]] <- evaluate_all_treatment_hypotheses(c_trt_m3_fit, model_subgroup = "complex_group", adjusted ="all treatments, supportive, age, sex, hospital", model_check = "OK", cl = cl, serie_data_28 = serie_data_28)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["c_trt_rate_m3"]])
if(do_pp_checks) {
  do_common_pp_checks(c_trt_m3_fit)
}
rm(c_trt_m3_fit)

Treatments + age + sex + hospital, first wave

The same model as above, but fitted to first-wave patients only. Fitted coefficients:

complex_initial_states_fw <- serie_data_fw %>% filter(.time == 1) %>% arrange(.serie) %>% pull(.observed) %>% paste0(., "_worsening")

c_trt_m3_fw <- brmshmmdata(trt_group_m3_formula, serie_data_fw, 
                  complex_rate_data, complex_hidden_state_data, complex_initial_states_fw, prior = trt_group_m1_prior,
                  observed_state_data = base_observed_state_data)

c_trt_m3_fw_fit <- brmhmm(c_trt_m3_fw, cache_file = paste0(cache_dir,"/c_trt_m3_fw.rds"), init = 0.1, control = default_control, iter = 1200)

print_hmm_fit_summary(c_trt_m3_fw_fit)

Estimates:

hmm_hypothesis_res_list[["c_trt_rate_m3_fw"]] <- evaluate_all_treatment_hypotheses(c_trt_m3_fw_fit, model_subgroup = "complex_group", adjusted ="all treatments, supportive, age, sex, hospital, first_wave", model_check = "OK", cl = cl, serie_data_28 = serie_data_28_fw)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["c_trt_rate_m3_fw"]])
if(do_pp_checks) {
  do_common_pp_checks(c_trt_m3_fw_fit)
}
# Estimates for favipiravir (FPV) for the meta-analysis
hmm_fit <- c_trt_m3_fw_fit
rm(c_trt_m3_fw_fit)
numerator_serie_data <- serie_data_28_fw %>% mutate(took_favipiravir = 1)
denominator_serie_data <- serie_data_28_fw %>% mutate(took_favipiravir = 0)


# Counterfactual data: every first-wave patient appears twice, once with
# favipiravir (numerator) and once without (denominator)
combined_data <- hmm_fit$data
combined_data$serie_data <- rbind(
  numerator_serie_data %>% mutate(.serie = paste0("__num_", .serie), is_numerator = TRUE),
  denominator_serie_data %>% mutate(.serie = paste0("__denom_", .serie), is_numerator = FALSE)
) %>%
  mutate(.serie = factor(.serie))

combined_data$initial_states <- rep(hmm_fit$data$initial_states, 2)

# Hidden states that correspond to the observed Oxygen / Ventilated levels
oxygen_states_ids <- which(hmm_fit$data$hidden_state_data$corresponding_obs == "Oxygen")
ventilation_states_ids <- which(hmm_fit$data$hidden_state_data$corresponding_obs == "Ventilated")

# factor() sorts the "__denom_" levels before the "__num_" levels, so the
# code assumes the denominator series come first and the numerator series second
n_series <- length(hmm_fit$data$initial_states)
numerator_indices <- 1:n_series + n_series
denominator_indices <- 1:n_series

# Patients at risk of newly needing oxygen: those starting in the AA state
numerator_indices_oxygen <- numerator_indices[hmm_fit$data$initial_states == "AA_worsening"]
denominator_indices_oxygen <- denominator_indices[hmm_fit$data$initial_states == "AA_worsening"]

# Patients at risk of newly needing ventilation: those not ventilated at the start
numerator_indices_ventilated <- numerator_indices[hmm_fit$data$initial_states != "Ventilated_worsening"]
denominator_indices_ventilated <- denominator_indices[hmm_fit$data$initial_states != "Ventilated_worsening"]

# Simulate hidden-state trajectories, repeated to obtain more draws
states_rep <- 4
states_list <- list()
for(i in 1:states_rep) {
  states_list[[i]] <- posterior_epred_rect(hmm_fit, newdata = combined_data, method = posterior_epred_rect_simulate, cores = 6)
}

# Stack the repetitions along the third (draw) dimension
states <- do.call(abind::abind, args = c(states_list, list(along = 3)))

# For each series and draw: was a ventilation / oxygen state ever visited?
needed_ventilation <- apply(states, MARGIN = c(2,3), FUN = function(x) { any(x %in% ventilation_states_ids) })
needed_oxygen <- apply(states, MARGIN = c(2,3), FUN = function(x) { any(x %in% oxygen_states_ids) })

# Per-draw probabilities and log odds ratio for needing oxygen
samples_oxygen_probs_numerator <- colMeans(needed_oxygen[numerator_indices_oxygen, ])
samples_oxygen_probs_denominator <- colMeans(needed_oxygen[denominator_indices_oxygen, ])

log_oxygen_odds_numerator <- log(samples_oxygen_probs_numerator) - log1p(-samples_oxygen_probs_numerator)
log_oxygen_odds_denominator <- log(samples_oxygen_probs_denominator) - log1p(-samples_oxygen_probs_denominator)

log_OR_oxygen <- log_oxygen_odds_numerator - log_oxygen_odds_denominator

# The same for needing ventilation
samples_ventilation_probs_numerator <- colMeans(needed_ventilation[numerator_indices_ventilated, ])
samples_ventilation_probs_denominator <- colMeans(needed_ventilation[denominator_indices_ventilated, ])

log_ventilation_odds_numerator <- log(samples_ventilation_probs_numerator) - log1p(-samples_ventilation_probs_numerator)
log_ventilation_odds_denominator <- log(samples_ventilation_probs_denominator) - log1p(-samples_ventilation_probs_denominator)

log_OR_ventilation <- log_ventilation_odds_numerator - log_ventilation_odds_denominator

# In a very small fraction of draws no patient is predicted to need
# ventilation, which gives NaN or infinite log odds ratios; drop those draws
log_OR_ventilation <- log_OR_ventilation[is.finite(log_OR_ventilation)]

  death_res <- hmm_hypothesis_res_list[["c_trt_rate_m3_fw"]] %>% filter(hypothesis == "favipiravir_death")

res_fpv <- tibble::tribble(
  ~event, ~estimate, ~se, ~`lower95`, ~`upper95`,
  "Death",  death_res$point_estimate, death_res$sd, death_res$ci_low, death_res$ci_high,
  "Oxygen", mean(log_OR_oxygen), sd(log_OR_oxygen), quantile(log_OR_oxygen, 0.025), quantile(log_OR_oxygen, 0.975),
  "Ventilated", mean(log_OR_ventilation), sd(log_OR_ventilation), quantile(log_OR_ventilation, 0.025, na.rm=TRUE), quantile(log_OR_ventilation, 0.975, na.rm=TRUE)
  )

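fpv_meta_dir <- here::here("local_temp_data", "fpv_meta")
if(!dir.exists(fpv_meta_dir)) {
  dir.create(fpv_meta_dir, recursive = TRUE)
}
write_csv(res_fpv, file = file.path(fpv_meta_dir, "main_events.csv"))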
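The favipiravir computation compares two counterfactual copies of every first-wave patient and, for each posterior draw, turns the share of at-risk trajectories that ever reach an oxygen (or ventilation) state into a log odds ratio. A compact base-R sketch of that step follows. It assumes a simulated state array with dimensions [time, series, draw], with the denominator (untreated) copies occupying the first n series. log_or_ever_visited is a hypothetical helper and the small array is made up for illustration. It uses qlogis(p), which equals the log(p) - log1p(-p) expression in the code.

```r
# Hypothetical helper: per-draw log odds ratio of ever visiting a target set
# of hidden states, treated (numerator) vs untreated (denominator) copies.
# states: integer array [time, series, draw] of simulated hidden-state ids;
#   series 1..n are the denominator copies, n+1..2n the numerator copies.
# target_ids: ids of the states of interest (e.g. all Oxygen states).
# at_risk: logical vector of length n marking the patients included.
log_or_ever_visited <- function(states, target_ids, at_risk) {
  n <- length(at_risk)
  stopifnot(dim(states)[2] == 2 * n)
  # ever_visited[series, draw]: did the trajectory reach a target state?
  ever_visited <- apply(states, MARGIN = c(2, 3),
                        FUN = function(x) any(x %in% target_ids))
  denominator_rows <- which(at_risk)
  numerator_rows <- which(at_risk) + n
  p_num <- colMeans(ever_visited[numerator_rows, , drop = FALSE])
  p_den <- colMeans(ever_visited[denominator_rows, , drop = FALSE])
  log_or <- qlogis(p_num) - qlogis(p_den)
  # Draws where a probability is exactly 0 or 1 give infinite or NaN
  # log odds ratios; drop them
  log_or[is.finite(log_or)]
}

# Made-up example: 4 patients (8 series), 2 time points, 2 draws;
# state 1 stands for ambient air, state 2 for oxygen
sim_states <- array(1L, dim = c(2, 8, 2))
sim_states[2, c(1, 2, 5), 1] <- 2L        # draw 1: 2/4 untreated vs 1/4 treated
sim_states[2, c(1, 2, 3, 4, 5), 2] <- 2L  # draw 2: 4/4 untreated, so it is dropped
log_or <- log_or_ever_visited(sim_states, target_ids = 2L, at_risk = rep(TRUE, 4))
log_or  # a single value, -log(3)
```

The original code relies on the same layout: factor() sorts the "__denom_" series before the "__num_" ones. It also assumes the prefixed series keep the relative order of hmm_fit$data$initial_states, which may be worth checking if .serie holds numbers, since as text "10" sorts before "2".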

Treatments + age + sex + hospital + comorbidities

Adding comorbidities with the formula:

trt_group_m4_formula

Here, comorbidities_sum is a comorbidity score as described in Section \@ref(comorbiditiessum).

Fitted coefficients (note the divergent transitions, which indicate that the results are not completely trustworthy):

c_trt_m4 <- brmshmmdata(trt_group_m4_formula, serie_data, 
                  complex_rate_data, complex_hidden_state_data, complex_initial_states, prior = trt_group_m1_prior,
                  observed_state_data = base_observed_state_data)  


c_trt_m4_fit <- brmhmm(c_trt_m4, cache_file = paste0(cache_dir,"/c_trt_m4.rds"), init = 0.1, control = default_control, iter = 1200)

print_hmm_fit_summary(c_trt_m4_fit)

Estimates (not completely trustworthy due to the divergent transitions):

hmm_hypothesis_res_list[["c_trt_rate_m4"]] <- evaluate_all_treatment_hypotheses(c_trt_m4_fit, model_subgroup = "complex_group", adjusted ="all treatments, supportive, age, sex, hospital, comorbidities", model_check = "Suspicious", cl = cl, serie_data_28 = serie_data_28)

print_hmm_hypothesis_res(hmm_hypothesis_res_list[["c_trt_rate_m4"]])
if(do_pp_checks) {
  do_common_pp_checks(c_trt_m4_fit)
}
rm(c_trt_m4_fit)
invisible(parallel::clusterEvalQ(cl, gc()))
hmm_hypothesis_res_all <- do.call(rbind, hmm_hypothesis_res_list)
write_csv(hmm_hypothesis_res_all, file = here::here("manuscript", "hmm_res.csv"))


cas-bioinf/covid19retrospective documentation built on Sept. 7, 2021, 6:19 p.m.