R/pfilter_fit.R

Defines functions pfilter_fit

Documented in pfilter_fit

#' Particle filtering
#'
#' Iteratively fits model parameters: every parameter sample ("particle") is
#' run through `model`, scored with `measure`, and the particles are then
#' resampled with probability proportional to `1 / sqrt(score)` and perturbed
#' by a random factor in `[0.9, 1.1]`. Iteration stops when the median score
#' is at most `eps` and lies within 10% of the medians of the two preceding
#' iterations, or when `max.iter` iterations have been run.
#'
#' Simulations are evaluated in parallel through the `future` framework; the
#' future plan that was active before the call is restored on exit.
#'
#' @author Ta-Chou Ng
#'
#' @param model function that runs one simulation. It is called with the
#'   elements of `fixed.paras` and one row of sampled parameters as arguments.
#' @param fitted.paras specification of the parameters to fit; passed
#'   unchanged to `pfilter_lhs()`, which defines the expected format.
#' @param fixed.paras list of arguments passed unchanged to every call of
#'   `model` (may be `NULL`).
#' @param measure function that takes the output of `model` and returns a
#'   single number scoring the simulation; smaller values get a larger
#'   resampling weight. Values are assumed to be non-negative.
#' @param max.iter maximum number of iterations.
#' @param n.sam number of parameter samples (particles) per iteration.
#' @param eps upper bound on the median score required for convergence.
#'
#' @return A list with elements
#'   \describe{
#'     \item{fitted.sams}{the current set of parameter samples. After
#'       convergence these are the samples that produced `measure.stats`; if
#'       `max.iter` was exhausted they have been resampled and perturbed once
#'       more after the last evaluation.}
#'     \item{prior.sams}{the initial samples returned by `pfilter_lhs()`.}
#'     \item{measure.sumvect}{`c(0, 0)` followed by the median score of each
#'       iteration.}
#'     \item{measure.stats}{the per-sample scores of the last evaluated
#'       iteration, or `NULL` if no iteration was run.}
#'   }
#' @export
#'
#' @importFrom data.table rbindlist as.data.table
#' @importFrom future.apply future_sapply
#' @importFrom future plan availableCores
#'
pfilter_fit <- function(model = NULL, fitted.paras = NULL, fixed.paras = NULL,
                        measure = NULL, max.iter = 200, n.sam = 1000, eps = .1){

  if (!is.function(model) || !is.function(measure)) {
    stop("`model` and `measure` must both be functions.", call. = FALSE)
  }

  sams <- pfilter_lhs(fitted.paras, n = n.sam)
  sams0 <- sams

  # Two leading zeros so that the "previous two medians" always exist; a
  # positive median can never fall inside a band around zero, so convergence
  # cannot be declared before the third iteration.
  ksDs.med_vect <- c(0, 0)
  # Defined up front so the return value is valid even if max.iter < 1.
  ksDs <- NULL
  # Relative tolerance used to decide that the median score has stabilised.
  pmar <- .1

  # One simulation: run the model on a full argument list and score it.
  wrap.f <- function(x) {
    measure(do.call(model, x))
  }

  # TRUE when `value` lies strictly inside the +/- pmar band around `ref`.
  near.prev <- function(value, ref) {
    value > ref * (1 - pmar) && value < ref * (1 + pmar)
  }

  # `future::multiprocess` is deprecated/defunct; pick the backend it used to
  # resolve to: forked workers where supported, background sessions otherwise.
  strategy <- if (future::supportsMulticore()) {
    future::multicore
  } else {
    future::multisession
  }
  n.workers <- max(future::availableCores() - 2, 1)
  # plan() returns the previous plan; restore it even if an error occurs.
  old.plan <- future::plan(strategy, workers = n.workers)
  on.exit(future::plan(old.plan), add = TRUE)

  iter <- 1
  while (iter <= max.iter) {

    paraLS <- lapply(seq_len(n.sam), function(i) {
      c(fixed.paras, as.list(sams[i, ]))
    })

    # future.seed = TRUE gives every simulation its own parallel-safe RNG
    # stream, which matters whenever `model` draws random numbers.
    ksDs <- future.apply::future_vapply(paraLS, wrap.f, numeric(1),
                                        future.seed = TRUE)

    cat("\n iteration ", iter, " completed.")
    iter <- iter + 1

    if (all(is.na(ksDs))) {
      stop("`measure` returned NA for every simulation; cannot continue.",
           call. = FALSE)
    }

    # Failed simulations (NA score) are ignored in the summary statistic.
    ksDs.med <- median(ksDs, na.rm = TRUE)
    n.med <- length(ksDs.med_vect)
    ksDs.med_1 <- ksDs.med_vect[n.med]
    ksDs.med_2 <- ksDs.med_vect[n.med - 1]
    ksDs.med_vect <- c(ksDs.med_vect, ksDs.med)

    converged <- near.prev(ksDs.med, ksDs.med_1) &&
      near.prev(ksDs.med, ksDs.med_2) &&
      ksDs.med <= eps
    if (converged) {
      cat("\n Converged.")
      break
    }

    # Resampling weights: inverse square root of the score. The score is
    # floored at machine epsilon so a perfect fit (score 0) yields a large
    # finite weight instead of Inf, and failed simulations get weight zero;
    # sample.int() rejects NA or infinite probabilities.
    wts <- 1 / sqrt(pmax(ksDs, .Machine$double.eps))
    wts[!is.finite(wts)] <- 0

    new.sams <- sams[sample.int(NROW(sams), n.sam, prob = wts, replace = TRUE), ]
    # Jitter every resampled value by an independent factor in [0.9, 1.1].
    sams <- new.sams[, lapply(.SD, function(x) {
      x * runif(length(x), min = 0.9, max = 1.1)
    })]
  }

  list(
    fitted.sams = sams,
    prior.sams = sams0,
    measure.sumvect = ksDs.med_vect,
    measure.stats = ksDs
  )
}
dachuwu/DTQbp documentation built on Dec. 19, 2021, 8:01 p.m.