R/viral_load_rt.R

Defines functions viral_load_estimate_rt

Documented in viral_load_estimate_rt

#' Viral load estimate rt process
#'
#' Fits the virosolver SEIR incidence model to a single cross-section of Ct
#' values with lazymcmc, then reconstructs the time-varying reproduction
#' number from posterior draws as Rt(t) = S(t) * R0, where S(t) is the
#' susceptible fraction of the SEIR solution shifted by the sampled `t0`.
#'
#' @param ct_data Data frame of Ct observations. Must contain a column `t`
#'   (sampling time); only rows with `t == obs_time` are used for the fit.
#' @param R0_partab Parameter control table (columns `names`, `values`,
#'   `fixed`, ...) in the format used by virosolver/lazymcmc. Its `values`
#'   serve as prior means and are also used to draw the reference
#'   "True Rt" curve, so that curve is only meaningful for simulated data.
#' @param obs_time Sampling time of the cross-section to fit.
#' @param nsamps Number of posterior draws used to build the Rt band
#'   (capped at the number of draws available after burn-in).
#' @param solve_times Times at which incidence and Rt are evaluated; also
#'   sets the x-axis limits of the plot.
#' @param population Population size passed to `simulate_seir_process()`
#'   for the reference curve.
#' @param nchains Number of MCMC chains to run.
#' @param ncores Number of workers in the parallel cluster.
#' @param chain_dir Directory the MCMC chain files are written to and read
#'   back from; created if it does not exist.
#'
#' @return A list with elements `Rt_quants` (data frame with columns `time`,
#'   `lower95`, `median`, `upper95`), `trajs` (per-draw trajectories with
#'   columns `t`, `inc`, `Rt`, `samp`) and `p_rt` (ggplot of the Rt band).
#'   Elements are also accessible by position, in that order.
#'
#' @family Viral Load
viral_load_estimate_rt <- function(ct_data, R0_partab,
                                   obs_time = 69,
                                   nsamps = 20,
                                   solve_times = 0:200,
                                   population = 8000,
                                   nchains = 1,
                                   ncores = 4,
                                   chain_dir = file.path(
                                     ".", "data", "mcmc_chains",
                                     "readme_single_cross_section"
                                   )) {
  cl <- parallel::makeCluster(ncores, setup_strategy = "sequential")
  # Always release the workers and unregister the (soon dead) backend, even
  # if fitting fails part-way through.
  on.exit({
    parallel::stopCluster(cl)
    foreach::registerDoSEQ()
  }, add = TRUE)
  doParallel::registerDoParallel(cl)

  means <- setNames(R0_partab$values, R0_partab$names)

  ## Prior standard deviations; the prior means come from R0_partab.
  sds_seir <- c("obs_sd" = 0.5, "viral_peak" = 2,
                "wane_rate2" = 1, "t_switch" = 3, "level_switch" = 1,
                "prob_detect" = 0.03,
                "incubation" = 0.25, "infectious" = 0.5)
  prior_func_seir <- .vl_seir_prior_func(means, sds_seir)

  ## Maps a named parameter vector to daily infection probabilities.
  incidence_function <- solveSEIRModel_lsoda_wrapper

  ## Evaluated only as an early failure check; the value is discarded.
  posterior_func <- create_posterior_func(parTab = R0_partab,
                                          data = ct_data,
                                          PRIOR_FUNC = prior_func_seir,
                                          INCIDENCE_FUNC = incidence_function,
                                          use_pos = FALSE)
  posterior_func(R0_partab$values)

  ## `.env` so a column called `obs_time` cannot shadow the argument.
  ct_data_use <- ct_data %>% filter(t == .env$obs_time)
  if (nrow(ct_data_use) == 0) {
    stop("No Ct observations with t == ", obs_time, call. = FALSE)
  }

  ## Random starting values that give a finite posterior.
  start_tab <- generate_viable_start_pars(R0_partab, ct_data_use,
                                          create_posterior_func,
                                          incidence_function,
                                          prior_func_seir)
  n_free <- sum(start_tab$fixed == 0)
  covMat <- diag(nrow(start_tab))
  ## 2.38 / sqrt(d): standard random-walk Metropolis scaling for d free
  ## parameters.
  mvrPars <- list(covMat, 2.38 / sqrt(n_free), w = 0.8)
  mcmc_pars <- c("iterations" = 20000, "popt" = 0.234, "opt_freq" = 2000,
                 "thin" = 1000, "adaptive_period" = 100000,
                 "save_block" = 100)
  ## Chains are written to, and re-read from, this same directory.
  dir.create(chain_dir, recursive = TRUE, showWarnings = FALSE)

  ## Run the MCMC; results are read back from disk below.
  foreach(chain_no = seq_len(nchains),
          .packages = c("virosolver", "lazymcmc", "extraDistr",
                        "tidyverse", "patchwork")) %dopar% {
    run_MCMC(parTab = start_tab,
             data = ct_data_use,
             INCIDENCE_FUNC = incidence_function,
             PRIOR_FUNC = prior_func_seir,
             mcmcPars = mcmc_pars,
             filename = file.path(chain_dir, paste0("readme_seir_", chain_no)),
             CREATE_POSTERIOR_FUNC = create_posterior_func,
             mvrPars = mvrPars,
             use_pos = FALSE) ## Important argument
  }

  chains <- load_mcmc_chains(location = chain_dir,
                             parTab = start_tab,
                             burnin = mcmc_pars["adaptive_period"],
                             chainNo = FALSE,
                             unfixed = FALSE,
                             multi = TRUE)
  chain_comb <- chains$chain %>%
    as_tibble() %>%
    mutate(sampno = seq_len(n())) %>%
    as.data.frame()
  if (nrow(chain_comb) == 0) {
    stop("No MCMC samples left after burn-in in ", chain_dir, call. = FALSE)
  }

  ## Posterior draws without replacement, never more than are available.
  n_draw <- min(nsamps, nrow(chain_comb))
  samps <- chain_comb$sampno[sample.int(nrow(chain_comb), n_draw)]

  dat_inc <- lapply(seq_along(samps), function(ii) {
    draw <- .vl_draw_trajectory(get_index_pars(chain_comb, samps[ii]),
                                solve_times, incidence_function)
    draw$samp <- ii
    draw
  })
  trajs <- bind_rows(dat_inc)

  Rt_quants <- trajs %>%
    group_by(time = t) %>%
    summarize(lower95 = quantile(Rt, 0.025, na.rm = TRUE),
              median = median(Rt, na.rm = TRUE),
              upper95 = quantile(Rt, 0.975, na.rm = TRUE))

  ## Reference Rt from the parameter-table values (the simulation truth).
  epidemic_process <- simulate_seir_process(means, solve_times, population)
  true_rt <- epidemic_process$seir_outputs %>%
    pivot_wider(names_from = "variable", values_from = "value") %>%
    filter(time <= max(Rt_quants$time))

  p_rt <- ggplot(Rt_quants) +
    geom_hline(yintercept = 1, linetype = "dashed", col = "grey70") +
    geom_ribbon(aes(x = time, ymin = lower95, ymax = upper95,
                    fill = "95% CI"), alpha = 0.3) +
    geom_line(aes(x = time, y = median, col = "Estimate Rt")) +
    geom_line(data = true_rt,
              aes(x = time, y = Rt, col = "True Rt"), linetype = "dashed") +
    scale_color_manual(name = "",
                       values = c("Estimate Rt" = "black",
                                  "True Rt" = "blue")) +
    scale_fill_manual(name = "", values = c("95% CI" = "gray40")) +
    scale_x_continuous(limits = range(solve_times)) +
    ylab("time-varying Rt") +
    xlab("Time") +
    theme_bw()
  # ggsave('./output/viral_load_rt.pdf', p_rt)

  list(Rt_quants = Rt_quants, trajs = trajs, p_rt = p_rt)
}

#' Build the log-prior function for the SEIR Ct model
#'
#' Normal priors on the Ct-model parameters, a moment-matched Beta prior on
#' `prob_detect` (keeps it inside (0, 1)) and log-normal priors on the SEIR
#' durations. All quantities that do not depend on `pars` are computed once
#' here rather than on every MCMC step.
#'
#' @param means Named numeric vector of prior means.
#' @param sds Named numeric vector of prior standard deviations.
#' @return A function `(pars, ...)` giving the summed log prior density of
#'   the named parameter vector `pars`.
#' @noRd
.vl_seir_prior_func <- function(means, sds) {
  normal_names <- c("obs_sd", "viral_peak", "wane_rate2", "t_switch",
                    "level_switch")
  lnorm_names <- c("incubation", "infectious")
  needed <- c(normal_names, "prob_detect", lnorm_names)
  missing_names <- setdiff(needed, intersect(names(means), names(sds)))
  if (length(missing_names) > 0) {
    stop("Missing prior mean/sd for: ",
         paste(missing_names, collapse = ", "), call. = FALSE)
  }

  normal_means <- means[normal_names]
  normal_sds <- sds[normal_names]
  lnorm_meanlog <- log(means[lnorm_names])
  lnorm_sds <- sds[lnorm_names]

  ## Beta(alpha, beta) with the requested mean and standard deviation.
  p_mean <- means[["prob_detect"]]
  p_sd <- sds[["prob_detect"]]
  beta_alpha <- ((1 - p_mean) / p_sd^2 - 1 / p_mean) * p_mean^2
  beta_beta <- beta_alpha * (1 / p_mean - 1)

  function(pars, ...) {
    sum(dnorm(pars[normal_names], normal_means, normal_sds, log = TRUE)) +
      dbeta(pars["prob_detect"], beta_alpha, beta_beta, log = TRUE) +
      sum(dlnorm(pars[lnorm_names], lnorm_meanlog, lnorm_sds, log = TRUE))
  }
}

#' Incidence and Rt trajectory for one posterior draw
#'
#' @param pars Named parameter vector for one draw (needs at least `R0`,
#'   `incubation`, `infectious`, `I0`, `t0` plus whatever
#'   `incidence_function` reads).
#' @param solve_times Times at which to evaluate the trajectory.
#' @param incidence_function Function `(pars, times)` returning incidence.
#' @return Tibble with columns `t`, `inc` (normalised to sum to 1) and `Rt`.
#' @noRd
.vl_draw_trajectory <- function(pars, solve_times, incidence_function) {
  n_times <- length(solve_times)
  inc_est <- incidence_function(pars, solve_times)[seq_len(n_times)]
  inc_est <- inc_est / sum(inc_est)

  seir_pars <- c(pars["R0"] * (1 / pars["infectious"]),
                 1 / pars["incubation"],
                 1 / pars["infectious"])
  names(seir_pars) <- c("beta", "sigma", "gamma")
  init <- c(1 - pars["I0"], 0, pars["I0"], 0)
  sol <- solveSEIRModel_lsoda(solve_times, init, seir_pars)
  ## Fully susceptible before the epidemic seed time t0, then S(t).
  S_compartment <- c(rep(1, pars["t0"]), sol[, 2])[seq_len(n_times)]

  tibble(t = solve_times,
         inc = inc_est,
         Rt = S_compartment * unname(pars["R0"]))
}
RussellXu/rtestimate documentation built on Jan. 1, 2022, 7:18 p.m.