R/DistributionFit.R

#' General distribution fitting algorithms
#' @export
DistributionFit = R6::R6Class("DistributionFit", inherit=PassthroughFilesystemCache, public = list(

  models = NULL,
  censored = NULL,
  groupedDf = NULL,
  grps = NULL,
  fitData = NULL,
  fittedModels = NULL,
  bootstraps = NULL,
  samples = NULL,
  shifted = 0,
  
  #' @description New distribution fitter
  #' @param distributions - the distributions to fit
  #' @param shifted - a shift to apply to all data
  #' @param ... for compatibility
  #' @return the fitter
  initialize = function(distributions = c("weibull","gamma","lnorm","exp","norm","tnorm"), shifted = 0, ...) {
    super$initialize(...)
    self$setModels(distributions, shifted)
  },
  
  setModels = function(distributions = c("weibull","gamma","lnorm","exp","norm","tnorm"), shifted = 0) {
    tmp = DistributionFit$standardDistributions
    self$models = tmp[names(tmp) %in% distributions]
    self$shifted = shifted
    invisible(self)
  },
  
  fromUncensoredData = function(groupedDf, valueExpr = value, truncate = TRUE, bootstraps=100,seed=101,...) {
    dots = rlang::list2(...)
    self$censored = FALSE
    self$groupedDf = groupedDf
    self$grps = grps = groupedDf %>% groups()
    valueExpr=enexpr(valueExpr)
    errors = NULL
    self$fitData = groupedDf %>% mutate(value = !!valueExpr) %>% mutate(value = value+self$shifted) %>% select(!!!grps,value) %>% group_by(!!!grps)
    #TODO: cache this?
    out = self$fitData %>% group_modify( function(d,g,...) {
      if(any(d$value == 0)) warning("Input contains zero values after shift.")
      if(any(d$value < 0)) warning("Input contains negative values after shift.")
      tmp = data.frame()
      for (m in self$models) {
        #browser()
        values = d %>% pull(value)
        if (truncate) {
          if (any(values>m$support[2]|values<m$support[1])) message("Truncating data for ",m$name)
          values = values[values<m$support[2] & values>m$support[1]]
        }
        
        if(m$discrete) {
          values = round(values)
        }
        
        arg = dots
        arg$data = values 
        arg$distr = m$name
        arg$start = m$start
        arg$lower=unlist(m$lower)
        #browser()
        if (length(m$fix.arg)>0) {arg$fix.arg=m$fix.arg}
        
        dist = suppressWarnings({
          tryCatch(do.call(fitdistrplus::fitdist, args=arg), error = function(e) {errors <<- c(errors,e$message);NULL})
        })
        
        if (is.null(dist)) {
          message("No fit for ",m$name)
        } else {
          res = self$extractFitted(dist, isCensored = FALSE, bootstraps = bootstraps,seed = seed)
          tmp = tmp %>% bind_rows(res$fittedModels)
          # nasty hack: cross join the group keys back onto the bootstraps (by = character() is a cartesian join)
          bstr = res$bootstraps %>% full_join(g, by = character())
          self$bootstraps = self$bootstraps %>% bind_rows(bstr)
        }
      }
      return(tmp)
    })
    self$fittedModels = distributionsDefinition(out)
    self$samples = NULL
    if (!identical(errors, NULL)) warning(unique(errors))
    invisible(self)
  },
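  
  # Illustrative usage sketch (not run). The data and groupings below are hypothetical,
  # and assume the parent cache class can be constructed with its defaults:
  #
  #   delays = tibble::tibble(
  #     ageCat = rep(c("<65", "65+"), each = 200),
  #     delay = c(rgamma(200, shape = 2, rate = 0.2), rgamma(200, shape = 3, rate = 0.2))
  #   ) %>% dplyr::group_by(ageCat)
  #
  #   fitter = DistributionFit$new(distributions = c("gamma", "weibull", "lnorm"))
  #   fitter$fromUncensoredData(delays, valueExpr = delay, bootstraps = 100)
  #   fitter$printDistributionSummary()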
  
  fromCensoredData = function(groupedDf, lowerValueExpr, upperValueExpr, truncate = TRUE, bootstraps = 100, seed=101,...) {
    lowerValueExpr=enexpr(lowerValueExpr)
    upperValueExpr=enexpr(upperValueExpr)
    dots = rlang::list2(...)
    errors=NULL
    
    self$censored = TRUE
    self$groupedDf = groupedDf
    self$grps = grps = groupedDf %>% groups()
    self$fitData = groupedDf %>% mutate(left = !!lowerValueExpr, right = !!upperValueExpr)  %>% mutate(left = left+self$shifted, right=right+self$shifted) %>% select(!!!grps,left,right) %>% group_by(!!!grps)
    #TODO: cache this?
    out = self$fitData %>% group_modify( function(d,g,...) {
      if(any(d$left == 0,na.rm=TRUE)) warning("Input contains zero values after shift.")
      if(any(d$left < 0,na.rm=TRUE)) warning("Input contains negative values after shift.")
      tmp = data.frame()
      for (m in self$models) {
        #browser()
        values = d %>% select(left,right)
        oldLen = nrow(values)
        if (truncate) {
          
          values = values %>% 
            filter(is.na(left) | (left<m$support[2] & left>m$support[1])) %>%
            filter(is.na(right) | (right<m$support[2] & right>m$support[1]))
          if (oldLen > nrow(values)) message("Truncating data for ",m$name)
        }
        
        if(m$discrete) {
          values = values %>% mutate(
            left = round(left),
            right = round(right),
          )
        }
        arg = dots
        arg$censdata = as.data.frame(values)
        arg$distr = m$name
        arg$start = m$start
        arg$lower=unlist(m$lower)
        #browser()
        if (length(m$fix.arg)>0) {arg$fix.arg=m$fix.arg}
        
        dist = suppressWarnings({
          tryCatch(do.call(fitdistrplus::fitdistcens, args=arg), error = function(e) {errors <<- c(errors,e$message);NULL})
        })
        
        #browser()
        if (is.null(dist)) {
          message("No fit for ",m$name)
        } else {
          res = self$extractFitted(dist, isCensored = TRUE, bootstraps = bootstraps, seed=seed)
          tmp = tmp %>% bind_rows(res$fittedModels)
          # nasty hack: cross join the group keys back onto the bootstraps (by = character() is a cartesian join)
          bstr = res$bootstraps %>% full_join(g, by = character())
          self$bootstraps = self$bootstraps %>% bind_rows(bstr)
        }
      }
      return(tmp)
    })
    
    self$fittedModels = distributionsDefinition(out)
    self$samples = NULL
    if (!identical(errors, NULL)) warning(unique(errors))
    invisible(self)
  },
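  
  # Illustrative sketch (not run): interval censored input is supplied as per-observation
  # lower/upper bounds, with NA marking a censored side (as for fitdistrplus::fitdistcens).
  # The data here are hypothetical, daily-interval censored:
  #
  #   x = rgamma(200, shape = 2, rate = 0.2)
  #   admissions = tibble::tibble(left = floor(x), right = ceiling(x)) %>%
  #     dplyr::mutate(right = ifelse(runif(200) < 0.1, NA, right))  # ~10% right censored
  #
  #   fitter = DistributionFit$new(c("gamma", "weibull"))
  #   fitter$fromCensoredData(admissions, lowerValueExpr = left, upperValueExpr = right)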
  
  # fromFittedDistributions = function(fittedModelsDf) {
  #   self$fittedModels = distributionsDefinition(fittedModelsDf)
  #   browser()
  #   cols = colnames(self$fittedModels)
  #   cols = cols[!(cols %in% c("dist","param","mean","sd","lower","upper"))]
  #   self$grps = grps = sapply(cols,as.symbol)
  #   #if(identical(self$grps,NULL)) self$grps = list()
  #   self$fittedModels = self$fittedModels %>% group_by(!!!grps)
  #   self$createBootstraps(100)
  #   self$generateSamples(100)
  #   self$fitData = self$samples
  #   invisible(self)
  # },
  
  withSingleDistribution = function(dist, paramDf, bootstraps = 1000, epiestimMode = FALSE, ...) {
    dots = rlang::list2(...)
    distParams = names(DistributionFit$conversionFrom[[dist]])
    # is the distribution fully specified?
    if (all(distParams %in% paramDf$param) & !any(is.na(c(paramDf$mean,paramDf$sd,paramDf$lower,paramDf$upper)))) {
      # use native parameters. Assume a truncated normal.
      tmp_bootstraps = paramDf %>% group_by(param) %>% group_modify(function(d,g,...) {
        if(nrow(d) > 1) stop(paste0("param ",g$param," is not unique"))
        print(list(mean = d$mean, sd = d$sd, lower = d$lower, upper = d$upper))
        return(tibble(
          bootstrapNumber = 1:bootstraps,
          value = msm::rtnorm(n=bootstraps, mean = d$mean, sd = d$sd, lower = d$lower, upper = d$upper)
        ))
      })
      self$bootstraps = self$bootstraps %>% bind_rows(tmp_bootstraps %>% mutate(dist = dist) %>% mutate(...))
      self$fittedModels = self$fittedModels %>% bind_rows(paramDf %>% mutate(dist = dist) %>% mutate(...))
    } else if (all(c("mean","sd") %in% paramDf$param)) {
      # use moments to define distribution
      converted = DistributionFit$convertParameters(dist = dist, paramDf = paramDf, bootstraps = bootstraps, epiestimMode = epiestimMode, ...)
      self$bootstraps = self$bootstraps %>% bind_rows(converted$bootstraps %>% mutate(...))
      self$fittedModels = self$fittedModels %>% bind_rows(converted$fittedModels %>% mutate(...))
      # TODO: could formally check for intersection of previous and new groups, or that dots contains all the correct args.
      # TODO: rethink what is fit data here - we ought to be able to generate a single "mid market" sample: self$fitData = self$samples
      # I think we only need it for plotting. maybe better to disable that part of plotting if it is not present.
    } else {
      stop("paramDf does not define mean and sd, or other distribution parameters")
    }
    if (identical(self$grps,NULL))  self$grps = sapply(names(dots),as.symbol)
    self$censored = FALSE
    invisible(self)
    
  },
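  
  # Illustrative sketch (not run): define a distribution from published summary statistics
  # rather than raw data. The figures and the "source" grouping column are hypothetical;
  # parameterDefinition() is assumed to be the package helper that normalises paramDf:
  #
  #   siParams = tibble::tibble(
  #     param = c("mean", "sd"),
  #     mean  = c(5.0, 2.0),
  #     sd    = c(0.5, 0.3),
  #     lower = c(4.0, 1.5),
  #     upper = c(6.0, 2.7)
  #   )
  #   fitter = DistributionFit$new()
  #   fitter$withSingleDistribution("gamma", siParams, bootstraps = 1000, source = "assumed")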
  
  fromBootstrappedDistributions = function(fittedDistributions, confint=c(0.025,0.975), ...) {
    self$setModels(unique(fittedDistributions$dist))
    cols = colnames(fittedDistributions)
    grps = sapply(cols[!(cols %in% c("bootstrapNumber","dist","param","value"))], as.symbol)
    self$grps = grps
    self$bootstraps = fittedDistributions
    self$fittedModels = fittedDistributions %>% group_by(!!!grps,dist,param) %>% summarise(mean=mean(value),sd=sd(value),lower=quantile(value,confint[1]),upper=quantile(value,confint[2]))
    self$censored = FALSE
    invisible(self)
  },
  
  fromBootstrappedData = function(bootstrappedDf, valueExpr = value, truncate = TRUE,...) {
    dots = rlang::list2(...)
    valueExpr = ensym(valueExpr)
    errors = NULL
    grps = bootstrappedDf %>% groups()
    self$grps = grps
    self$groupedDf = bootstrappedDf
    #browser()
    self$bootstraps = bootstrappedDf %>% ensurer::ensure_that("bootstrapNumber" %in% colnames(.) ~ "bootstrapDf must have a bootstrapNumber column") %>%
      group_by(!!!grps, bootstrapNumber) %>% 
      group_modify(function(d,g,...) {
        tmpBootstrap = tibble(dist=NA_character_,param=NA_character_,value=NA_real_) %>% filter(!is.na(dist))
        for (m in self$models) {
          values = d %>% mutate(x = !!valueExpr) %>% pull(x)
          values = values+self$shifted
          if (truncate) {
            #if (any(values>m$support[2]|values<m$support[1])) message("Truncating data for ",m$name)
            values = values[values<=m$support[2] & values>=m$support[1]]
          }
          if(m$discrete) {
            values = round(values)
          }
          
          #control=list(trace=1, REPORT=1)
          arg = dots
          arg$data = values 
          arg$distr = m$name
          arg$start = m$start
          arg$lower=unlist(m$lower)
          #browser()
          if (length(m$fix.arg)>0) {arg$fix.arg=m$fix.arg}
          
          dist2 = suppressWarnings({
            tryCatch(do.call(fitdistrplus::fitdist, args=arg), error = function(e) {errors <<- c(errors,e$message);NULL})
          })
          
          if (!is.null(dist2)) {
            # build a bootstraps entry
            #browser()
            tmpBootstrap = tmpBootstrap %>% bind_rows(
              tibble(
                dist = dist2$distname,
                param = names(dist2$estimate),
                value = dist2$estimate,
                n = dist2$n,
                loglik = dist2$loglik,
                aic = dist2$aic,
                bic = dist2$bic
              )
            )
            if(length(m$fix.arg)>0) {
              tmpBootstrap = tmpBootstrap %>% bind_rows(
                tibble(
                  dist = dist2$distname, 
                  param = names(m$fix.arg),
                  value = unlist(m$fix.arg),
                  n = dist2$n,
                  loglik = dist2$loglik,
                  aic = dist2$aic,
                  bic = dist2$bic
                )
              )
            }
          }
        }
        return(tmpBootstrap)
      }) %>% group_by(!!!grps)
    
    self$fittedModels = self$bootstraps %>% group_by(!!!grps,dist,param) %>%
      summarise(
        mean = mean(value,na.rm = TRUE),
        sd = sd(value,na.rm = TRUE),
        lower = quantile(value, 0.025),
        upper = quantile(value, 0.975),
        n = sum(n),
        loglik = mean(loglik),
        aic = mean(aic),
        bic = mean(bic)
      ) %>% group_by(!!!grps)
    
    self$bootstraps = self$bootstraps %>% select(!!!grps,bootstrapNumber,dist,param,value)
    
    # self$groupedDf = bootstrappedDf %>% group_by(!!!grps) %>% summarise(
    #   !!paste0("Mean.",as_label(valueExpr) := mean(!!valueExpr)))
    self$fitData =  bootstrappedDf %>% mutate(
      value = !!valueExpr+self$shifted) %>%
      select(!!!grps,value)
    
    self$samples = bootstrappedDf %>% mutate(value = !!valueExpr) %>% 
      select(!!!grps, bootstrapNumber, value) %>% group_by(!!!grps, bootstrapNumber) %>% mutate(sampleNumber = row_number(), dist="empirical")
    self$censored = FALSE
    if (!identical(errors,NULL)) warning(unique(errors))
    invisible(self)
  },
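  
  # Illustrative sketch (not run): refit the models to each bootstrap replicate of an
  # upstream estimate. The input must already carry a bootstrapNumber column; data here
  # are hypothetical:
  #
  #   boots = tidyr::crossing(bootstrapNumber = 1:50, obs = 1:100) %>%
  #     dplyr::mutate(delay = rgamma(dplyr::n(), shape = 2, rate = 0.2))
  #
  #   fitter = DistributionFit$new(c("gamma"))
  #   fitter$fromBootstrappedData(boots, valueExpr = delay)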
  
  plot = function(xlim, binwidth = 1, summary=FALSE, pts=8, facet2d = TRUE) {
    
    #TODO: attribute fractions of survivalDf to different times based on left and right censoring.
    # This needs a review of the structure of the whole thing
    # c(min(self$fitData$value),max(self$fitData$value))
    
    #browser()
    support = seq(xlim[1],xlim[2],length.out = 201)
    pdfs = self$calculateDensities(support+self$shifted)
    
    grps = self$grps
    
    
    
    
    p1 = ggplot()
    if (!identical(self$fitData,NULL)) {
      if(!self$censored) {
        p1 = p1 + geom_histogram(data=self$fitData, aes(x=value-self$shifted, y=..density..),fill=NA,colour = "grey50",binwidth = binwidth)
      } else {
        p1 = p1 + 
          geom_histogram(data=self$fitData, aes(x=left-self$shifted, y=..density..),fill="grey70",colour = NA,alpha=0.5,binwidth = binwidth) +
          geom_histogram(data=self$fitData, aes(x=right-self$shifted, y=..density..),fill="grey70",colour = NA,alpha=0.5,binwidth = binwidth)
      }
    }
    
    disc = unlist(map(DistributionFit$standardDistributions, ~ .$discrete))
    
    cont = pdfs %>% filter(dist %in% names(disc)[!disc])
    if(nrow(cont) > 0) {
      p1 = p1 + geom_line(data = cont, aes(x=value-self$shifted,y=Mean.density,colour = dist), size = 1) +
        geom_ribbon(data = cont, aes(x=value-self$shifted, ymin=Quantile.0.025.density,ymax=Quantile.0.975.density,fill = dist),colour=NA,alpha=0.1) +
        geom_ribbon(data = cont, aes(x=value-self$shifted, ymin=Quantile.0.25.density,ymax=Quantile.0.75.density,fill = dist),colour=NA,alpha=0.15)
    }
    
    if(summary) {
      labels = self$printDistributionSummary() %>% group_by(!!!self$grps) %>% summarise(label = paste0(dist," ",param,": ",`Mean (95% CI)`,collapse = "\n"))
      p1 =p1 + geom_label(data=labels, aes(x = Inf, y = Inf, label = label), fill=NA, label.size=0, label.padding = unit(2, "lines"), hjust="inward", vjust="inward", size=pts/ggplot2:::.pt/(96/72))
    }
    
    discont = pdfs %>% filter(dist %in% names(disc)[disc])
    if(nrow(discont) > 0) {
      # p1 = p1 + geom_point(data = discont, aes(x=value-self$shifted,y=Mean.density,colour = dist), size = 1,show.legend = FALSE) +
      #   geom_errorbar(data = discont, aes(x=value-self$shifted, ymin=Quantile.0.025.density,ymax=Quantile.0.975.density,colour = dist,fill=dist)) +
      #   geom_errorbar(data = discont, aes(x=value-self$shifted, ymin=Quantile.0.25.density,ymax=Quantile.0.75.density,colour = dist))
      
      p1 = p1 + geom_boxplot(data = discont, aes(x=value-self$shifted,
                                                 y=Mean.density,
                                                 ymin=Quantile.0.025.density,
                                                 lower=Quantile.0.25.density,
                                                 middle=Quantile.0.5.density,
                                                 upper=Quantile.0.75.density,
                                                 ymax=Quantile.0.975.density,
                                                 colour = dist, fill=dist, group=paste0(dist,round(value-self$shifted))),
                             stat="identity", width=0.5, alpha = 0.25, show.legend = FALSE)
    }
      
    p1 = p1 + coord_cartesian(xlim = xlim)
    
    # https://stackoverflow.com/questions/45908120/add-shaded-standard-error-curves-to-geom-density-in-ggplot2
    
    grps = self$grps
    if (length(grps) == 2 & facet2d) {
      p1 = p1 + facet_grid(rows = vars(!!grps[[1]]), cols = vars(!!grps[[2]]), scales="free", shrink = TRUE) 
    } else if(length(grps) != 0) {
      p1 = p1 + facet_wrap(facets = grps, scales="free", shrink = TRUE)
    }
    return(p1)
  },
  
  extractFitted = function(distFit, isCensored, bootstraps = 100, seed=101) {
    if (!is.null(distFit)) {
      set.seed(seed)
      if (isCensored) {
        dist2 = suppressWarnings(fitdistrplus::bootdistcens(distFit,silent = TRUE,niter = bootstraps,ncpus = 3))
        #browser()
      } else {
        dist2 = suppressWarnings(fitdistrplus::bootdist(distFit,silent = TRUE,niter = bootstraps,ncpus = 3))
      }
    }
    #browser()
    m=DistributionFit$standardDistributions[[dist2$fitpart$distname]]
    fittedModels=tibble(
        n=distFit$n,
        aic=distFit$aic,
        bic=distFit$bic,
        loglik=distFit$loglik,
        dist = dist2$fitpart$distname,
        param = names(distFit$estimate),
        mean = unname(sapply(dist2$estim,mean, na.rm=TRUE)), #dist$estimate,
        sd = unname(sapply(dist2$estim,sd, na.rm=TRUE)), #dist$sd,
        lower = matrix(dist2$CI,ncol=3)[,2],
        upper = matrix(dist2$CI,ncol=3)[,3]
      )
    if(length(m$fix.arg)>0) {
      fittedModels = fittedModels %>% bind_rows(
        tibble(
          n=distFit$n,
          aic=distFit$aic,
          bic=distFit$bic,
          loglik=distFit$loglik,
          dist = dist2$fitpart$distname,
          param = names(m$fix.arg),
          mean = unlist(m$fix.arg),
          sd = 0,
          lower = unlist(m$fix.arg),
          upper = unlist(m$fix.arg)
        )
      )
    }
      
    return(
      list(
        fittedModels=distributionsDefinition(fittedModels),
        # TODO: a goodness of fit measure here would be reassuring, but how this would integrate with the rest of the
        # code is a bit of a mystery. The problem is that some bootstraps may be a terrible fit but they are treated the
        # same as those that are a good fit. We could, I suppose, dispose of a certain fraction of the worst fitting bootstraps...
        bootstraps = dist2$estim %>% 
          dplyr::mutate(bootstrapNumber = row_number(), dist = dist2$fitpart$distname) %>% 
          tidyr::pivot_longer(names_to = "param", values_to = "value", cols=c(-bootstrapNumber,-dist))
        )
    )
  },
  
  filterModels = function(...) {
    clone = self$clone()
    clone$fittedModels = clone$fittedModels %>% group_by(!!!clone$grps) %>% filter(...)
    clone$bootstraps = clone$bootstraps %>% semi_join(clone$fittedModels, by=unlist(c(sapply(clone$grps,as_label),"dist")))
    if (!identical(clone$samples,NULL))
      clone$samples = clone$samples %>% semi_join(clone$fittedModels, by=unlist(c(sapply(clone$grps,as_label),"dist")))
    invisible(clone)
  },
  
  #' #' @description calculates a set of bootstrap parameter values
  #' #' @param bootstraps - number of bootstrap iterations 
  #' #' @return a list of randomly selected bootstraps from the fitted models.
  #' createBootstraps = function(bootstraps=100) {
  #'   grps = self$grps
  #'   distributionsDf = self$fittedModels %>% group_by(!!!grps,dist) %>% group_modify(function(d,g,...) {
  #'     #if(nrow(d) != 1) stop("A single covolution distribution must be defined for each group")
  #'     #func = paste0("p",d$distribution)
  #'     out = tibble(
  #'       #!!distributionVar = g %>% pull(distributionVar),
  #'       bootstrapNumber = 1:bootstraps
  #'     )
  #'     params = d %>% pull(param) %>% unique()
  #'     for (paramName in params) {
  #'       paramMean = d %>% filter(param == paramName) %>% pull(mean)
  #'       paramSd = d %>% filter(param == paramName) %>% pull(sd)
  #'       paramMin = d %>% filter(param == paramName) %>% pull(lower)
  #'       paramMax = d %>% filter(param == paramName) %>% pull(upper)
  #'       if(is.na(paramSd)) {
  #'         paramBoot = rep(paramMean, bootstraps)
  #'       } else {
  #'         stop("This is not working. It assumes parameter distributions are normal which they aren't")
  #'         # What I'll have to do is create bootstrap sample for all distributions when fitting
  #'         # this will retain the distributions
  #'         # the fittedModels will then be the 
  #'         paramBoot = msm::rtnorm(bootstraps, paramMean, paramSd, paramMin, paramMax)
  #'         
  #'       }
  #'       out = out %>% tibble::add_column(!!paramName := paramBoot)
  #'     }
  #'     return(out %>% tidyr::pivot_longer(cols = all_of(params), names_to = "param", values_to = "value"))
  #'   })
  #'   self$bootstraps = distributionsDf %>% group_by(!!!grps, bootstrapNumber)
  #'   invisible(self)
  #' },
  
  generateSamples = function(sampleExpr = 1000, seed = 101) {
    sampleExpr = enexpr(sampleExpr)
    if(identical(self$bootstraps, NULL)) self$bootstraps = self$fittedModels %>% dplyr::mutate(value = mean, bootstrapNumber=0)
    grps = self$grps
    set.seed(seed)
    bootstrapSamples = self$bootstraps %>% 
      mutate(tmp_samples = !!sampleExpr) %>% 
      group_by(!!!grps,dist,bootstrapNumber,tmp_samples) %>% 
      group_modify(function(d,g,...) {
        samples = g$tmp_samples
        func = paste0("r",g %>% pull(dist))
        params2 = d %>% select(param,value) %>% tibble::deframe()
        sampledDist = tibble(
          value = suppressWarnings(do.call(func, c(list(n=samples),params2))),
          sampleNumber = 1:samples
        )
        return(sampledDist %>% filter(!is.na(value)))
      })
    #browser()
    self$samples = bootstrapSamples %>% ungroup() %>% select(-tmp_samples) %>% dplyr::group_by(!!!grps) %>% dplyr::mutate(value = value-self$shifted)
    invisible(self)
  },
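  
  # Illustrative sketch (not run): draw random samples from every bootstrapped parameter
  # set, e.g. to feed a downstream simulation. sampleExpr is evaluated per bootstrap, so
  # it may be a constant or an expression involving bootstrap columns:
  #
  #   fitter$generateSamples(sampleExpr = 1000, seed = 101)
  #   fitter$samples %>% dplyr::summarise(mean = mean(value), sd = sd(value))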
  
  #' @description calculates a set of pdfs from the bootstrapped distributions
  #' @param x - the values at which to calculate the density
  #' @param summarise - if TRUE, return a seven-number summary over bootstraps; if FALSE, return one density per bootstrap
  #' @return a dataframe of densities by group and distribution
  calculateDensities = function(x, summarise=TRUE) {
    grps = self$grps
    tmp = self$bootstraps %>% group_by(!!!grps,dist,bootstrapNumber) %>% group_modify(function(d,g,...) {
      func = paste0("d",g %>% pull(dist))
      if (DistributionFit$standardDistributions[[g$dist]]$discrete) {
        x = round(min(x)):round(max(x))
      }
      params2 = d %>% select(param,value) %>% tibble::deframe()
      pdfs = tibble(
        density = suppressWarnings(do.call(func, c(list(x=x),params2))),
        value = x
      )
      return(pdfs)
    })
    if (summarise) return(tmp %>% self$sevenNumbers(density))
    else return(tmp)
  },

  #' @description calculates a set of cumulative distribution functions from the bootstrapped distributions
  #' @param q - the values at which to calculate the cumulative probability
  #' @param summarise - if TRUE, return a seven-number summary over bootstraps; if FALSE, return one value per bootstrap
  #' @return a dataframe of cumulative probabilities by group and distribution
  calculateCumulativeDistributions = function(q, summarise=TRUE) {
    grps = self$grps
    tmp = self$bootstraps %>% group_by(!!!grps,dist,bootstrapNumber) %>% group_modify(function(d,g,...) {
      func = paste0("p",g %>% pull(dist))
      if (DistributionFit$standardDistributions[[g$dist]]$discrete) {
        q = round(q)
      }
      params2 = d %>% select(param,value) %>% tibble::deframe()
      pdfs = tibble(
        cumulative = suppressWarnings(do.call(func, c(list(q=q),params2))),
        value = q
      )
      return(pdfs)
    })
    if (summarise) return(tmp %>% self$sevenNumbers(cumulative))
    else return(tmp)
  },
  
  #' @description calculates a set of quantiles from the bootstrapped distributions
  #' @param p - the probabilities for which to calculate quantiles
  #' @param summarise - if TRUE, return a seven-number summary over bootstraps; if FALSE, return one value per bootstrap
  #' @return a dataframe of quantiles by group and distribution
  calculateQuantiles = function(p, summarise=TRUE) {
    grps = self$grps
    tmp = self$bootstraps %>% group_by(!!!grps,dist,bootstrapNumber) %>% group_modify(function(d,g,...) {
      func = paste0("q",g %>% pull(dist))
      params2 = d %>% select(param,value) %>% tibble::deframe()
      pdfs = tibble(
        quantile = suppressWarnings(do.call(func, c(list(p=p),params2))),
        probability = p
      )
      return(pdfs)
    })
    if (summarise) return(tmp %>% self$sevenNumbers(quantile))
    else return(tmp)
  },
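  
  # Illustrative sketch (not run): interrogate the bootstrapped fits. Each call returns a
  # seven-number summary per group and distribution unless summarise = FALSE:
  #
  #   fitter$calculateDensities(x = seq(0, 30, 0.5))           # pdf over a grid
  #   fitter$calculateCumulativeDistributions(q = 0:30)        # cdf at integer delays
  #   fitter$calculateQuantiles(p = c(0.05, 0.5, 0.95))        # central quantiles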
  
  discreteProbabilities = function(q, summarise=TRUE) {
    tmp = self$calculateCumulativeDistributions(q, summarise=FALSE) %>% ungroup() %>% 
      group_by(across(c(-value,-cumulative))) %>% 
      arrange(value) %>%
      mutate(discreteProbability = cumulative-lag(cumulative, default = 0)) %>%
      select(-cumulative)
    if (summarise) return(tmp %>% self$sevenNumbers(discreteProbability))
    else return(tmp)
  },
  
  discreteSurvival = function(q, summarise=TRUE) {
    tmp = self$calculateCumulativeDistributions(q, summarise=FALSE) %>% ungroup() %>% 
      group_by(across(c(-value,-cumulative))) %>% 
      arrange(value) %>%
      mutate(survival = 1-cumulative) %>%
      select(-cumulative)
    if (summarise) return(tmp %>% self$sevenNumbers(survival))
    else return(tmp)
  },
  
  sampledProbabilities = function(q, summarise=TRUE) {
    #TODO: offset? shift?
    if(identical(self$samples, NULL)) self$generateSamples()
    self$samples %>% mutate(
      band = cut(value,breaks = c(-Inf,q),labels = q,ordered_result = TRUE)
    ) %>% mutate(
      value = q[as.integer(band)]
    ) %>% group_by(!!!self$grps,value) %>% summarise(
      sampledProbability = n()/nrow(self$samples)
    ) %>% tidyr::complete(value = q, fill = list(sampledProbability = 0))
  },
  
  sevenNumbers = function(groupedDf, colVar) {
    grps = self$grps
    colVar = ensym(colVar)
    name = function(pref) paste0(pref,".",as_label(colVar))
    groupedDf %>% ungroup() %>% group_by(across(c(-bootstrapNumber,-(!!colVar)))) %>% summarise(
      !!name("Mean") := mean(!!colVar ,na.rm=TRUE),
      !!name("Sd") := sd(!!colVar ,na.rm=TRUE),
      !!name("Quantile.0.025") := quantile(!!colVar, probs = 0.025, na.rm=TRUE),
      !!name("Quantile.0.25") := quantile(!!colVar, probs = 0.25, na.rm=TRUE),
      !!name("Quantile.0.5") := quantile(!!colVar, probs = 0.5, na.rm=TRUE),
      !!name("Quantile.0.75") := quantile(!!colVar, probs = 0.75, na.rm=TRUE),
      !!name("Quantile.0.975") := quantile(!!colVar, probs = 0.975, na.rm=TRUE),
    )
    #TODO: could make this configurable with purrr nested list
    #levels = c(0.025,0.25,0.5,0.75,0.975)
    #quan = quantile(sample, levels) %>% tibble::enframe() %>% mutate(name = paste0("Quantile."),levels) %>% pivot_wider(names_from="name", values_from="value")
    #followed by an unnest or flatten
  },
  
  printDistributionSummary = function(confint = c(0.025,0.975)) {
    grps = self$grps
    tmp = DistributionFit$unconvertParameters(bootstrapsDf = self$bootstraps %>% group_by(!!!grps), confint=confint)
    ci = floor((confint[2]-confint[1])*100)
    #TODO: we don't technically know if the self$fittedModels parameters have the same confidence intervals.
    # could recalculate them from self$bootstraps
    tmp = tmp %>% dplyr::mutate( #self$fittedModels %>% bind_rows(tmp) %>% mutate(
      `Distribution` = purrr::map_chr(dist, ~ self$models[[.x]][["print"]]),
      shift = self$shifted,
      #!!(sprintf("Mean \U00B1 SD (%1d%% CI)",ci)) := sprintf("%1.2f \U00B1 %1.2f (%1.2f; %1.2f)",mean,sd,lower,upper)
      !!(sprintf("Mean (%1d%% CI)",ci)) := sprintf("%1.2f (%1.2f; %1.2f)",mean,lower,upper)
    ) %>% mutate(
      `Distribution` = ifelse(shift>0, paste0(`Distribution`, "(shifted ",shift,")"), `Distribution`)
    )
    return(tmp)
  },
  
  printDistributionDetail = function(confint = c(0.025,0.975)) {
    grps = self$grps
    tmp = DistributionFit$unconvertParameters(bootstrapsDf = self$bootstraps %>% group_by(!!!grps), confint=confint)
    ci = floor((confint[2]-confint[1])*100)
    #TODO: we don't technically know if the self$fittedModels parameters have the same confidence intervals.
    # could recalculate them from self$bootstraps
    tmp = self$fittedModels %>% bind_rows(tmp) %>% mutate(
      `Distribution` = purrr::map_chr(dist, ~ self$models[[.x]][["print"]]),
      shift = self$shifted,
      !!(sprintf("Mean \U00B1 SD (%1d%% CI)",ci)) := sprintf("%1.2f \U00B1 %1.2f (%1.2f; %1.2f)",mean,sd,lower,upper)
    ) %>% mutate(
      `Distribution` = ifelse(shift>0, paste0(`Distribution`, "(shifted ",shift,")"), `Distribution`)
    ) %>% group_by(dist,.add=TRUE)
    
    if ("aic" %in% names(tmp)) tmp = tmp %>% mutate(aic= first(na.omit(aic)))
    if ("n" %in% names(tmp)) tmp = tmp %>% mutate(n = first(na.omit(n)))
    if ("bic" %in% names(tmp)) tmp = tmp %>% mutate(bic= first(na.omit(bic)))
    if ("loglik" %in% names(tmp)) tmp = tmp %>% mutate(loglik= first(na.omit(loglik)))
    return(tmp)
  },
  
  ## TODO: integrate the following functions ----
  
  #' @description applies a set of parameterised convolution functions groupwise to input data
  #' @param groupedDf - an optionally grouped dataframe containing at least dateVar and a valueVar to be convolved
  #' @param distributionsDf - a dataframe containing the same grouping columns as groupedDf plus a distribution column and the relevant distribution parameter columns
  #' @param dateVar - the date column in groupedDf
  #' @param days - the maximum delay (in days) considered by the convolution
  #' @param timepoints - the timepoints at which to discretise the delay distribution (defaults to 0:days)
  #' @param padLeft - what can we assume about the run in to the current values? default NA.
  #' @param padRight - what can we assume about the run out from the current values? default NA.
  #' @return groupedDf with an additional convolved output column
  
  tsParameterizedConvolution = function(groupedDf, distributionsDf, outputVar = "output", valueVar = "value", dateVar="date", distributionVar = "dist", paramNameVar = "param", paramValueVar = "paramValue", days = 30, timepoints = 0:days, padLeft = NA_real_, padRight = NA_real_) {
    grps = groupedDf %>% groups()
    distributionVar = ensym(distributionVar)
    paramNameVar = ensym(paramNameVar)
    paramValueVar = ensym(paramValueVar)
    
    discreteDistDf = distributionsDf %>% group_by(!!!grps,!!distributionVar) %>% group_modify(function(d,g,...) {
      #if(nrow(d) != 1) stop("A single covolution distribution must be defined for each group in groupedDf")
      func = paste0("p",g %>% pull(!!distributionVar))
      
      params2 = paramList = d %>% select(!!paramNameVar,!!paramValueVar) %>% tibble::deframe()
      
      #funcCDF = paste0("p",d$distribution)
      #params = formals(func)
      #params2 = d %>% select(any_of(names(params))) %>% as.list()
      start = timepoints[1:(length(timepoints)-1)]
      end = timepoints[2:(length(timepoints))]
      discreteDist = tibble(
        start = start,
        end = end,
        prob = do.call(func, c(q=list(end),params2))-do.call(func, c(q=list(start),params2)),
        surv = 1-do.call(func, c(q=list(start),params2))
      )
      return(discreteDist)
    })
    
    self$tsDiscreteConvolution(groupedDf, discreteDistDf, {{outputVar}}, {{valueVar}}, {{dateVar}}, padLeft = padLeft, padRight = padRight)
    
  },
  
  #' @description applies a set of pre-discretised convolution distributions groupwise to input data
  #' @param groupedDf - an optionally grouped dataframe containing at least dateVar and a valueVar to be convolved
  #' @param discreteDistDf - a dataframe containing the same grouping columns as groupedDf plus "start", "end" and "prob" columns
  #' @param dateVar - the date column in groupedDf
  #' @param padLeft - what can we assume about the run in to the current values? default NA.
  #' @param padRight - what can we assume about the run out from the current values? default NA.
  #' @return groupedDf with an additional convolved output column
  
  tsDiscreteConvolution = function(groupedDf, discreteDistDf, outputVar = "output", valueVar="value", dateVar="date", pExpr="prob", padLeft=NA_real_, padRight=NA_real_ ) {
    grps = groupedDf %>% groups()
    valueVar = ensym(valueVar)
    dateVar = ensym(dateVar)
    outputVar = ensym(outputVar)
    pExpr = ensym(pExpr)
    joinCols = c(sapply(grps, as_label),"tmp_join")
    finalJoinCols = c(sapply(grps, as_label),as_label(dateVar))
    
    dateRanges = groupedDf %>% summarise(min_date = min(!!dateVar), max_date = max(!!dateVar)) %>% mutate(tmp_join=1)
    
    combinations = discreteDistDf %>% select(!!!grps,start) %>% mutate(tmp_join=1) %>% inner_join(dateRanges, by=joinCols)
    
    padLeftDf = combinations %>% mutate(tmp_value=padLeft, tmp_date=min_date-start-1) %>% select(-start)
    padRightDf = combinations %>% mutate(tmp_value=padRight, tmp_date=max_date+start+1) %>% select(-start)
    
    tmp = groupedDf %>% select(!!!grps,tmp_value=!!valueVar, tmp_date=!!dateVar) %>% group_by(!!!grps) %>%
      mutate(tmp_join=1, min_date = min(tmp_date), max_date = max(tmp_date))
    
    tmp = bind_rows(padLeftDf,tmp,padRightDf)
    
    tmp2 = tmp %>% 
      inner_join(discreteDistDf %>% mutate(tmp_join=1), by=joinCols)
    
    tmp3 = tmp2 %>% mutate(eff_date = tmp_date+start, eff_value = tmp_value*!!pExpr) %>% select(-tmp_join) %>%
      filter(eff_date <= max_date & eff_date >= min_date)
    
    tmp4 = tmp3 %>% group_by(!!!grps,eff_date) %>% summarise(!!outputVar := sum(eff_value)) %>% mutate(!!dateVar := as.Date(eff_date,"1970-01-01")) %>% select(-eff_date)
    
    return(groupedDf %>% inner_join(tmp4, by=finalJoinCols))
  },
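  
  # Illustrative sketch (not run): convolve a daily case count series with a discretised
  # delay distribution, e.g. to project a downstream admissions series. All numbers,
  # column names and the "area" grouping are hypothetical:
  #
  #   cases = tidyr::crossing(
  #     area = c("A", "B"),
  #     date = as.Date("2020-03-01") + 0:59
  #   ) %>%
  #     dplyr::mutate(value = rpois(dplyr::n(), lambda = 100)) %>%
  #     dplyr::group_by(area)
  #
  #   delay = tidyr::crossing(area = c("A", "B"), start = 0:13) %>%
  #     dplyr::mutate(end = start + 1,
  #                   prob = pgamma(end, shape = 2, rate = 0.4) -
  #                          pgamma(start, shape = 2, rate = 0.4))
  #
  #   projected = fitter$tsDiscreteConvolution(cases, delay,
  #     outputVar = "admissions", valueVar = "value", dateVar = "date", padLeft = 0)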
  
 
  
  
  #' @description executes a convolution using bootstrapped parameterised distributions
  #' @param groupedDf - an optionally grouped dataframe containing at least dateVar and a valueVar to be convolved
  #' @param distributionDistDf - a grouped dataframe containing the same grouping columns as groupedDf, plus "distribution" and columns for "<parameter>_mean", "<parameter>_sd", "<parameter>_min", "<parameter>_max" e.g. shape_mean, shape_sd, shape_min, shape_max, rate_mean, rate_sd, rate_min, rate_max
  #' @param bootstraps - number of bootstrap iterations
  #' @return the full bootstrap result (i.e. not summarised), which can be further convolved (as long as the number of bootstrap iterations is kept the same)
  
  tsBootstrapConvolution = function(groupedDf, distributionDistDf, bootstraps=100, outputVar = "output", valueVar="value", dateVar="date", days = 30, timepoints = 0:days, padLeft=NA_real_, padRight=NA_real_) {
    grps = groupedDf %>% groups()
    tmp = groupedDf %>% crossing(tibble(bootstrapNumber = 1:bootstraps)) %>% group_by(!!!grps,bootstrapNumber)
    distributionDistDf = distributionDistDf %>% group_by(!!!grps)
    out = tmp %>% self$tsParameterizedConvolution(bootstrapDistributions(distributionDistDf, bootstraps = bootstraps), 
                                             outputVar = {{outputVar}}, valueVar = {{valueVar}}, dateVar = {{dateVar}}, days = days, timepoints = timepoints, padLeft = padLeft, padRight = padRight)
    out = out %>% ungroup() %>% group_by(!!!grps)
    return(out)
  },
  
  #' @description summarises the result of a bootstrapped convolution into quantile columns
  #' @param groupedDf - the grouped output of tsBootstrapConvolution, containing dateVar and the columns named in outputVars
  #' @param outputVars - the output columns to summarise, given as vars(...)
  #' @param dateVar - the date column in groupedDf
  #' @return a dataframe with median and quantile columns for each output variable, by group and date
  
  summariseBootstrap = function(groupedDf, outputVars = vars(output), dateVar = "date") {
    grps = groupedDf %>% groups()
    dateVar = ensym(dateVar)
    joinCols = c(sapply(grps,as_label),as_label(dateVar))
    out = groupedDf %>% select(!!!grps,!!dateVar) %>% distinct()
    p <- c(0.2, 0.5, 0.8)
    for(outputVar in outputVars) {
      
      tmp = groupedDf %>% group_by(!!!grps,!!dateVar) %>% summarise(
        !!(paste0("CI.0.025(",as_label(outputVar),")")) := quantile(!!outputVar,p=0.025),
        !!(paste0("CI.0.1(",as_label(outputVar),")")) := quantile(!!outputVar,p=0.1),
        !!(paste0("Median(",as_label(outputVar),")")) := quantile(!!outputVar,p=0.5),
        !!(paste0("CI.0.9(",as_label(outputVar),")")) := quantile(!!outputVar,p=0.9),
        !!(paste0("CI.0.975(",as_label(outputVar),")")) := quantile(!!outputVar,p=0.975),
      )
      
      out = out %>% inner_join(tmp, by=joinCols)
    }
    return(out)
  }
))

#### Standard distributions ----

DistributionFit$standardDistributions = list(
  weibull = list(name = "weibull", print = "weibull", start = list(shape =1 , scale=10), fix.arg = list(), lower=list(shape=0,scale=0), support=c(0,Inf), discrete=FALSE),
  gamma = list(name = "gamma", print = "gamma", start = list(shape =1 , rate=0.1), fix.arg = list(), lower=list(shape=0,rate=0), support=c(0,Inf), discrete=FALSE),
  lnorm = list(name = "lnorm", print = "log-normal", start = list(meanlog=1 , sdlog=1), fix.arg = list(), lower=list(meanlog=-Inf, sdlog=-Inf), support=c(0,Inf), discrete=FALSE),
  exp = list(name = "exp", print = "exponential", start = list(rate=0.5), fix.arg = list(), lower=list(rate=0.000001), support=c(.Machine$double.xmin, Inf), discrete=FALSE),
  norm = list(name = "norm", print = "normal", start = list(mean = 1, sd = 0.1), fix.arg = list(), lower=list(sd=0), support=c(-Inf,Inf), discrete=FALSE),
  tnorm = list(name = "tnorm", print = "truncated normal", start = list(mean = 1, sd = 0.1), fix.arg = list(lower=0,upper=Inf), support=c(0,Inf), lower=list(mean=0,sd=0.001), discrete=FALSE),
  nbinom = list(name = "nbinom", print = "negative binomial", start = list(size = 1, prob = 0.1), fix.arg = list(), lower=list(size=.Machine$double.xmin, prob=0), support=c(0,Inf), discrete=TRUE),
  pois = list(name = "pois", print = "poisson", start = list(lambda = 1), fix.arg = list(), lower=list(lambda=0), support=c(0,Inf), discrete=TRUE)
)

DistributionFit$conversionFrom = list(
  gamma = list(
    scale = function(mean,sd) sd^2/mean,
    shape = function(mean,sd) mean^2/sd^2
    #rate = function(mean,sd) sd^2/mean^2
  ),
  lnorm = list(
    meanlog = function(mean,sd) log(mean/sqrt(1+(sd^2)/(mean^2))),
    sdlog = function(mean,sd) sqrt(log(1+(sd^2)/(mean^2)))
  ),
  norm = list(
    mean = function(mean,sd) mean,
    sd = function(mean,sd) sd
  ),
  tnorm = list(
    mean = function(mean,sd) mean,
    sd = function(mean,sd) sd,
    lower = function(mean,sd) 0,
    upper = function(mean,sd) Inf
  ),
  weibull = list(
    shape = function(mean,sd) (sd/mean)^-1.086,
    scale = function(mean,sd) mean/gamma(1+1/((sd/mean)^-1.086))
  ),
  exp = list(
    rate = function(mean,sd) 1/mean
  ),
  nbinom = list(
    size = function(mean,sd) (mean^2)/(sd^2-mean),
    prob = function(mean,sd) mean/(sd^2) # prob = mean/variance for R's size/prob parameterisation
  ),
  pois = list(
    lambda = function(mean,sd) mean
  )
)

DistributionFit$conversionTo = list(
  gamma = list(
    mean = function(shape,scale = 1/rate, rate=1/scale) shape*scale,
    sd = function(shape,scale = 1/rate, rate=1/scale) sqrt(shape*(scale^2))
  ),
  lnorm = list(
    mean = function(meanlog,sdlog) exp(meanlog+(sdlog^2)/2),
    sd = function(meanlog,sdlog) sqrt((exp(sdlog^2)-1)*exp(2*meanlog+sdlog^2))
  ),
  norm = list(
    mean = function(mean,sd) mean,
    sd = function(mean,sd) sd
  ),
  tnorm = list(
    mean = function(mean,sd,lower,upper) mean,
    sd = function(mean,sd,lower,upper) sd
  ),
  weibull = list(
    mean = function(scale, shape) scale*gamma(1+1/shape),
    sd = function(scale, shape) sqrt(scale^2 * (gamma(1+2/shape) - (gamma(1+1/shape))^2))
  ),
  exp = list(
    mean = function(rate) 1/rate,
    sd = function(rate) 1/rate
  ),
  nbinom = list(
    mean = function(size,prob) size*(1-prob)/prob,
    sd = function(size,prob) sqrt(size*(1-prob)/(prob^2))
  ),
  pois = list(
    mean = function(lambda) lambda,
    sd = function(lambda) sqrt(lambda)
  )
)
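
# Worked example of the moment conversions above (not run): a gamma distribution with
# mean 5 and sd 2 has shape = mean^2/sd^2 = 6.25 and scale = sd^2/mean = 0.8, and the
# reverse conversions recover the original moments:
#
#   DistributionFit$conversionFrom$gamma$shape(mean = 5, sd = 2)        # 6.25
#   DistributionFit$conversionFrom$gamma$scale(mean = 5, sd = 2)        # 0.8
#   DistributionFit$conversionTo$gamma$mean(shape = 6.25, scale = 0.8)  # 5
#   DistributionFit$conversionTo$gamma$sd(shape = 6.25, scale = 0.8)    # 2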


DistributionFit$unconvertParameters = function(
  bootstrapsDf,
  confint = c(0.025,0.975),
  ...
) {
  grps = bootstrapsDf %>% groups()
  
  tmp = bootstrapsDf %>% group_by(!!!grps,dist) %>% group_modify(function(d,g,...) {
    paramList = d %>% dplyr::select(param,value) %>% tidyr::nest(data=c(value)) %>% tibble::deframe() %>% purrr::map(~ .$value)
    funcs = DistributionFit$conversionTo[[g$dist]]
    conv = tibble(
      param = "mean",
      value = suppressWarnings(do.call(funcs$mean, paramList)),
    ) %>% bind_rows(tibble(
      param = "sd",
      value = suppressWarnings(do.call(funcs$sd, paramList)),
    ))
  })
  tmp = tmp %>% group_by(!!!grps,dist, param) %>% summarise(
    lower = quantile(value, probs = confint[1],na.rm = TRUE),
    upper = quantile(value, probs = confint[2],na.rm = TRUE),
    mean = mean(value,na.rm = TRUE),
    sd = sd(value,na.rm = TRUE)
  )
  return(tmp)
}

DistributionFit$convertParameters = function(
    paramDf,
    N = NA_integer_,
    dist="gamma",
    confint = c(0.025,0.975),
    bootstraps = 1000, 
    epiestimMode = FALSE,
    ...
  ) {
  
    meanOfMean = parameterDefinition(paramDf) %>% filter(param=="mean") %>% pull(mean)
    sdOfMean = parameterDefinition(paramDf) %>% filter(param=="mean") %>% pull(sd)
    lowerOfMean = parameterDefinition(paramDf) %>% filter(param=="mean") %>% pull(lower)
    upperOfMean = parameterDefinition(paramDf) %>% filter(param=="mean") %>% pull(upper)
    meanOfSd = parameterDefinition(paramDf) %>% filter(param=="sd") %>% pull(mean)
    sdOfSd = parameterDefinition(paramDf) %>% filter(param=="sd") %>% pull(sd)
    lowerOfSd = parameterDefinition(paramDf) %>% filter(param=="sd") %>% pull(lower)
    upperOfSd = parameterDefinition(paramDf) %>% filter(param=="sd") %>% pull(upper)
  
    invalid = function(x) {is.null(x) || length(x) == 0 || any(is.na(x))}
  
  if (invalid(c(meanOfMean,meanOfSd))) stop("must supply non NA values for meanOfMean and meanOfSd")
  conversion = DistributionFit$conversionFrom[[dist]]
  paramSamples = NULL
  if (invalid(sdOfMean)) sdOfMean = (upperOfMean-lowerOfMean)/(qnorm(confint[2],0,1)-qnorm(confint[1],0,1))
  
  if (!invalid(sdOfMean) & invalid(sdOfSd)) {
    if (invalid(c(lowerOfSd, upperOfSd))) {
      #stop("must supply non NA values for: lowerOfSd and upperOfSd or sdOfSd")
      # browser()
      sdOfSd = 0
    } else {
      fit.sd.gamma = suppressWarnings(nls(y ~ qgamma(x, shape=tmp_shape, scale=meanOfSd/tmp_shape), data = tibble( x=c(confint[1],confint[2]), y=c(lowerOfSd, upperOfSd) )))
      sdShape = summary(fit.sd.gamma)$parameters[["tmp_shape",1]]
      sdScale = meanOfSd/sdShape
      sdOfSd = sqrt(sdShape*sdScale^2)
    }
    #TODO: we could have fixed the scale to 2 and used the moments to get the shape from the mean... This wouldn't use the lower and upper at all.
  }
  
  if ((invalid(sdOfMean) | sdOfMean == 0) & (invalid(sdOfSd) | sdOfSd == 0)) {
    
    bootstraps = 1
    boot_mean = meanOfMean
    boot_sd = meanOfSd
    
  } else {
  
    if (invalid(sdOfMean)) stop("either sdOfMean or upperOfMean and lowerOfMean must be specified")
    if (invalid(sdOfSd)) stop("either sdOfSd or upperOfSd and lowerOfSd must be specified")
    
    # central limit theorem - means are approximately normally distributed
    boot_mean = msm::rtnorm(bootstraps,mean = meanOfMean,sd = sdOfMean,lower = lowerOfMean, upper=upperOfMean) 
    if (epiestimMode) {
      boot_sd = msm::rtnorm(bootstraps,mean = meanOfSd,sd = sdOfSd,lower = lowerOfSd, upper=upperOfSd) 
    } else {
      if (is.na(N)) {
        if (sdOfSd == 0) {
          #browser()
          boot_sd = meanOfSd
        } else {
          sdShape = meanOfSd^2/sdOfSd^2
          sdScale = sdOfSd^2/meanOfSd # this should maybe be hard limited to 2 - as SD is chisq distributed and chisq is gamma with scale 2 (or rate 0.5).
          boot_sd = rgamma(bootstraps,shape = sdShape,scale = sdScale)
        }
      } else {
        boot_sd = meanOfSd*sqrt(rchisq(bootstraps, N-1)/(N-1))
      }
      if (!invalid(c(lowerOfSd, upperOfSd))) {
        boot_sd = boot_sd %>% scales::squish(range = c(lowerOfSd,upperOfSd))
      }
    }
    
  }
  # browser()
  for(paramName in names(conversion)) {
    paramSamples = paramSamples %>% bind_rows(tibble(
      dist = dist,
      bootstrapNumber = 1:bootstraps,
      param = paramName,
      value = conversion[[paramName]](boot_mean, boot_sd)
    ))  
  }
  #TODO: N.b. maybe paramSamples is the bootstrappedDistributions as generated by createBootstraps
  paramDfNew = paramSamples %>% group_by(dist,param) %>% summarise(
    n = N,
    mean = mean(value, na.rm = TRUE),
    sd = sd(value, na.rm = TRUE),
    lower = quantile(value,confint[1], na.rm = TRUE),
    upper = quantile(value,confint[2], na.rm = TRUE)
  )
  return(list(
    bootstraps = paramSamples,
    fittedModels = paramDfNew
  ))
}
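
# Illustrative sketch (not run): convert a published mean and sd (each with a 95% CI but
# no standard error) into bootstrapped gamma parameters. The figures are hypothetical,
# and parameterDefinition() is assumed to be the package helper that normalises paramDf:
#
#   paramDf = tibble::tibble(
#     param = c("mean", "sd"),
#     mean  = c(6.3, 4.2),
#     sd    = c(NA, NA),
#     lower = c(5.4, 3.1),
#     upper = c(7.3, 5.5)
#   )
#   conv = DistributionFit$convertParameters(paramDf, dist = "gamma", bootstraps = 1000)
#   conv$fittedModels   # summarised shape and scale
#   conv$bootstraps     # one shape/scale pair per bootstrap iteration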

#TODO: reversing this is a question of grabbing bootstraps and converting them back into mean and sd
# paramList = dfit$bootstraps %>% select(param,value) %>% nest(value) %>% tibble::deframe() %>% map(~ .$value)
# gives you a named list of vectors
# funcs = DistributionFit$conversionTo[["dist"]]
# mean = suppressWarnings(do.call(funcs$mean, paramList)),
# sd = suppressWarnings(do.call(funcs$sd, paramList)),
# which can be used to generate mean and sd 