# inst/ROPE_rt.R

library(writexl)
library(readxl)
library(dplyr)
library(rtestimate)
library(tidyverse)
library(patchwork)
library(lazymcmc)
library(virosolver)
library(foreach)
library(doParallel)
library(ggplot2)
library(extraDistr)
library(ggthemes)
library(odin) ## install from CRAN
library(fitdistrplus)



#' Log-prior density of the Ct (viral kinetics) and SEIR model parameters.
#'
#' Normal priors are used for the Ct-model parameters. A Beta prior, whose
#' shapes are moment-matched to the given mean/sd, keeps `prob_detect` within
#' (0, 1). Log-normal priors are used for the SEIR `incubation` and
#' `infectious` parameters.
#'
#' @param pars Named numeric vector of model parameters. It must contain
#'   obs_sd, viral_peak, wane_rate2, t_switch, level_switch, prob_detect,
#'   incubation and infectious.
#' @param ... Ignored.
#' @param prior_means Named numeric vector of prior means. It defaults to an
#'   object called `means`, found by normal lexical scoping from the
#'   environment this function is attached to (typically the global
#'   environment for this script).
#' @param prior_sds Named numeric vector of prior standard deviations. For the
#'   log-normal terms this is the sd on the log scale. It defaults to
#'   `sds_seir`, looked up the same way.
#' @return A single unnamed numeric value: the sum of the individual log
#'   densities. It errors if a required name is missing from `pars`,
#'   `prior_means` or `prior_sds`.
prior_func_seir <- function(pars, ..., prior_means = means,
                            prior_sds = sds_seir) {
  # `[[` errors on a missing name, unlike `x[which(names(x) == nm)]`, which
  # silently produced a zero-length prior.
  normal_lp <- function(name) {
    dnorm(pars[[name]], prior_means[[name]], prior_sds[[name]], log = TRUE)
  }
  lognormal_lp <- function(name) {
    dlnorm(pars[[name]], log(prior_means[[name]]), prior_sds[[name]],
           log = TRUE)
  }

  ## Beta prior on prob_detect. The shape parameters are moment-matched to
  ## the requested mean (mu) and sd (sigma), so the parameter stays in (0, 1).
  mu <- prior_means[["prob_detect"]]
  sigma <- prior_sds[["prob_detect"]]
  shape1 <- ((1 - mu) / sigma^2 - 1 / mu) * mu^2
  shape2 <- shape1 * (1 / mu - 1)

  ## Ct model priors + detection prior + SEIR model priors
  normal_lp("obs_sd") + normal_lp("viral_peak") + normal_lp("wane_rate2") +
    normal_lp("t_switch") + normal_lp("level_switch") +
    dbeta(pars[["prob_detect"]], shape1, shape2, log = TRUE) +
    lognormal_lp("incubation") + lognormal_lp("infectious")
}
# compare
#' Build the ROPE acceptance band around the true Rt curve.
#'
#' @param rope Half-width of the band. An estimate is accepted when it lies
#'   within `true_rt +/- rope`. It defaults to 0.6, the width previously
#'   hard-coded here (the argument used to be ignored).
#' @param path Excel file containing a numeric `true_rt` column.
#' @return The file's rows with columns `true_rt`, `upper_band` and
#'   `lower_band`.
true_rt <- function(rope = 0.6,
                    path = "~/Desktop/prism_use/rt2/true_rt_R2.xlsx") {
  .rt_band(read_excel(path), rope)
}

# Add upper/lower bounds of half-width `rope` around the `true_rt` column and
# keep only the three band columns.
.rt_band <- function(true, rope) {
  stopifnot(is.numeric(rope), length(rope) == 1L, !is.na(rope), rope >= 0)
  true$upper_band <- true$true_rt + rope
  true$lower_band <- true$true_rt - rope
  true[, c("true_rt", "upper_band", "lower_band")]
}


#' Share of EpiEstim posterior Rt curves that fall inside the ROPE band.
#'
#' @param percentage Minimum fraction of a curve's time points that must lie
#'   inside the band for that curve to count as accepted.
#' @param true_rt_band Data frame with `lower_band`/`upper_band` columns, as
#'   returned by `true_rt()`.
#' @param path Excel file containing an `incidence` column.
#' @return Fraction (0-1) of the columns of `simu_lists` that are accepted.
epiestim_rt <- function(percentage, true_rt_band,
                        path = "~/Desktop/prism_use/rt2/incidence_R2.xlsx") {
  seir_dynamic <- readxl::read_excel(path)
  incidence <- seir_dynamic$incidence[12:125]
  # Sliding windows of 14 consecutive time indices, starting at index 2.
  t_start <- seq(2, length(incidence) - 13)
  t_end <- t_start + 13
  res_parametric_si <- rtestimate::estimate_R(
    incidence,
    method = "parametric_si",
    config = make_config(list(
      t_start = t_start,
      t_end = t_end,
      mean_si = 4.4,
      std_si = 3.0
    ))
  )
  # To plot: epiestim_rt_plot(res_parametric_si, legend = FALSE, what = "R")

  # Rows 25:124 of the band are compared with the rows of simu_lists.
  # NOTE(review): this assumes the band rows line up with the estimation
  # windows above; confirm if the slicing changes.
  band <- true_rt_band[25:124, ]
  .rope_share(res_parametric_si$simu_lists, band$lower_band,
              band$upper_band, percentage)
}

# Fraction of curves (columns of `curves`) whose share of points inside
# [lower, upper] is at least `percentage`. `lower`/`upper` are compared down
# each column, one value per row. NA points count as outside the band.
.rope_share <- function(curves, lower, upper, percentage) {
  curves <- as.matrix(curves)
  n_time <- nrow(curves)
  if (length(lower) != n_time || length(upper) != n_time) {
    warning("band length (", length(lower), ") does not match the number ",
            "of time points (", n_time, "); values will be recycled",
            call. = FALSE)
  }
  inside_counts <- vapply(seq_len(ncol(curves)), function(j) {
    x <- curves[, j]
    sum(x >= lower & x <= upper, na.rm = TRUE)
  }, integer(1))
  mean(inside_counts / n_time >= percentage)
}
#write.csv(res_parametric_si$simu_lists, '~/Desktop/prism_use/rope_use/epiestim_rt.csv')

#' ROPE scoring of Rt curves reconstructed from Ct-value (viral load) fits.
#'
#' Loads MCMC chains for the Ct/SEIR model from `chain_dir` (via
#' `load_mcmc_chains`) and turns 20 random posterior draws into Rt
#' trajectories (Rt = S * R0). It summarises the trajectories by per-time
#' quantiles and scores each quantile curve against the true-Rt band in the
#' same way as `epiestim_rt()`.
#'
#' @param percentage Minimum fraction of a curve's time points that must lie
#'   inside the band for that curve to count as accepted. (It was previously
#'   ignored in favour of a hard-coded 0.95.)
#' @param true_rt_band Data frame with `lower_band`/`upper_band` columns, as
#'   returned by `true_rt()`. It is compared row by row with the time points.
#' @param ct_path CSV file of Ct observations.
#' @param chain_dir Directory passed to `load_mcmc_chains()`.
#' @return Long data frame with columns `t`, `Rt` and `flag` (the quantile
#'   index), unchanged from before. The ROPE acceptance fraction is attached
#'   as the attribute `"rope_proportion"`.
viral_load_rt <- function(percentage, true_rt_band,
                          ct_path = "~/Desktop/prism_use/rt2/ct_r2.csv",
                          chain_dir = "./data/mcmc_chains/readme_single_cross_section") {
  ct_data <- read.csv(ct_path)
  load("data/R0_partab.rda") # defines R0_partab in this frame
  means <- R0_partab$values
  names(means) <- R0_partab$names
  ## Set standard deviations of prior distribution
  sds_seir <- c("obs_sd" = 0.5, "viral_peak" = 2,
                "wane_rate2" = 1, "t_switch" = 3, "level_switch" = 1,
                "prob_detect" = 0.03,
                "incubation" = 0.25, "infectious" = 0.5)

  # prior_func_seir() is defined at top level. Through lexical scoping it
  # looks up `means`/`sds_seir` in its own (global) environment and never sees
  # the locals above. Re-home a copy in this frame so these priors are used.
  prior_func <- prior_func_seir
  environment(prior_func) <- environment()

  incidence_function <- solveSEIRModel_lsoda_wrapper
  ## Create the posterior function used in the MCMC framework
  posterior_func <- create_posterior_func(parTab = R0_partab,
                                          data = ct_data,
                                          PRIOR_FUNC = prior_func,
                                          INCIDENCE_FUNC = incidence_function,
                                          use_pos = FALSE) ## Important argument, see text
  # Sanity check: evaluate the posterior at the default parameters so that a
  # mis-specified model errors here. The value itself is not used.
  posterior_func(R0_partab$values)

  # Only the Ct observations at t == 69 are used for the starting values.
  ct_data_use <- ct_data %>% filter(t == 69)
  start_tab <- generate_viable_start_pars(R0_partab, ct_data_use,
                                          create_posterior_func,
                                          incidence_function,
                                          prior_func)
  mcmc_pars <- c("iterations" = 20000, "popt" = 0.234, "opt_freq" = 2000,
                 "thin" = 1000, "adaptive_period" = 100000, "save_block" = 100)
  chains <- load_mcmc_chains(location = chain_dir,
                             parTab = start_tab,
                             burnin = mcmc_pars["adaptive_period"],
                             chainNo = FALSE,
                             unfixed = FALSE,
                             multi = TRUE)
  chain_comb <- chains$chain %>%
    as_tibble() %>%
    mutate(sampno = seq_len(n())) %>%
    as.data.frame()

  # Rt trajectory for each of 20 random posterior draws.
  samps <- sample(unique(chain_comb$sampno), 20)
  solve_times <- 0:200
  dat_inc <- vector("list", length(samps))
  for (ii in seq_along(samps)) {
    pars <- get_index_pars(chain_comb, samps[ii])
    inc_est <- incidence_function(pars, solve_times)[seq_along(solve_times)]
    tmp_inc <- tibble(t = solve_times, inc = inc_est / sum(inc_est))

    seir_pars <- c(pars["R0"] * (1 / pars["infectious"]),
                   1 / pars["incubation"],
                   1 / pars["infectious"])
    names(seir_pars) <- c("beta", "sigma", "gamma")
    init <- c(1 - pars["I0"], 0, pars["I0"], 0)
    sol <- solveSEIRModel_lsoda(solve_times, init, seir_pars)
    # Pad with 1 for the first t0 time points, then use sol[, 2].
    # NOTE(review): assumes column 2 of `sol` is the susceptible fraction.
    S_compartment <- c(rep(1, pars["t0"]), sol[, 2])[seq_along(solve_times)]
    tmp_inc$Rt <- S_compartment * unname(pars["R0"])
    tmp_inc$samp <- ii
    dat_inc[[ii]] <- tmp_inc
  }
  trajs <- bind_rows(dat_inc)

  # Per-time quantiles of Rt across draws. Each t yields length(probs)
  # consecutive rows, and `flag` indexes the quantile within that run.
  probs <- c(seq(0.01, 0.99, 0.01), 0.999)
  viral_rope_frame <- trajs %>%
    dplyr::select(t, Rt) %>%
    group_by(t) %>%
    summarize(quantile(Rt, probs, na.rm = TRUE), .groups = "keep") %>%
    ungroup() %>%
    data.frame() %>%
    mutate(t = t + 1) %>%
    mutate(flag = rep_len(seq_along(probs), n()))
  names(viral_rope_frame) <- c("t", "Rt", "flag")

  # One column per quantile curve, one row per time point.
  quantile_curves <- viral_rope_frame %>%
    reshape2::dcast(t ~ flag, value.var = "Rt") %>%
    dplyr::select(-t) %>%
    as.matrix()
  lower <- true_rt_band$lower_band
  upper <- true_rt_band$upper_band
  # Points of each curve inside the band. NA counts as outside.
  inside_counts <- vapply(seq_len(ncol(quantile_curves)), function(j) {
    x <- quantile_curves[, j]
    sum(x >= lower & x <= upper, na.rm = TRUE)
  }, integer(1))
  viral_rope <- inside_counts / nrow(quantile_curves)
  attr(viral_rope_frame, "rope_proportion") <- mean(viral_rope >= percentage)
  viral_rope_frame
}


# ROPE comparison on the rt2 data read by these functions. First build the
# band around the true Rt curve (rope = 0.5 is the half-width requested from
# true_rt()). Then score each method's Rt curves against it; a curve counts
# as accepted when at least 95% of its points fall inside the band.
true_rt_band <- true_rt(rope = 0.5)
# Proportion of EpiEstim posterior curves that are accepted.
epiestim_rt(0.95, true_rt_band)
# Note: this returns (and autoprints) the per-time Rt quantile frame rather
# than a single acceptance proportion.
viral_load_rt(0.95, true_rt_band)







# Export every individual Rt curve (before computing the mean and the
# upper/lower bands) so those summaries can be produced downstream.

# --- Original EpiEstim -------------------------------------------------------
seir_dynamic <- readxl::read_excel("~/Desktop/prism_use/rt10/incidence_R10.xlsx")
# Sliding windows of 14 consecutive time indices, starting at index 2.
t_start <- seq(2, length(seir_dynamic$incidence[12:60]) - 13)
t_end <- t_start + 13
res_parametric_si <- rtestimate::estimate_R(seir_dynamic$incidence[12:60],
                                            method = "parametric_si",
                                            config = make_config(list(
                                              t_start = t_start,
                                              t_end = t_end,
                                              mean_si = 4.4,
                                              std_si = 3.0)))
epiestim_rt_plot(res_parametric_si, legend = FALSE, what = "R")
write.csv(res_parametric_si$simu_lists, "~/Desktop/for_modified_data/epiestim.csv")

# --- Modified EpiEstim -------------------------------------------------------
# (no code in this section yet)

# --- Viral load --------------------------------------------------------------
ct_data <- read.csv("~/Desktop/prism_use/rt10/ct_r10.csv")
load("data/R0_partab.rda")
R0_partab$values[R0_partab$names == "R0"] <- 10
Rt_quants <- viral_load_estimate_rt(ct_data, R0_partab)
# Per-time Rt quantiles. Each t yields length(quantile_probs) consecutive
# rows, and `flag` indexes the quantile. rep_len() avoids assuming exactly
# 201 time points in Rt_quants[[2]].
quantile_probs <- c(seq(0.01, 0.99, 0.01), 0.999)
viral_rope_frame <- Rt_quants[[2]] %>%
  dplyr::select(t, Rt) %>%
  group_by(t) %>%
  summarize(quantile(Rt, quantile_probs, na.rm = TRUE), .groups = "keep") %>%
  ungroup() %>%
  data.frame() %>%
  mutate(t = t + 1) %>%
  mutate(flag = rep_len(seq_along(quantile_probs), n()))
names(viral_rope_frame) <- c("t", "Rt", "flag")
# Wide layout: one column per quantile curve, one row per time point.
viral_rope <- viral_rope_frame %>%
  reshape2::dcast(t ~ flag, value.var = "Rt") %>%
  dplyr::select(-t)
write_csv(viral_rope, "~/Desktop/for_modified_data/viral_load.csv")
# RussellXu/rtestimate documentation built on Jan. 1, 2022, 7:18 p.m.