# Chunk defaults for the rendered document: hide source code and warnings, and give every figure the same size.
knitr::opts_chunk$set(echo=FALSE, warning=FALSE, fig.width=8, fig.height=4.5)
# Fix the RNG seed so the simulated data is reproducible.
set.seed(0)

import::from(mgcv, gam)
import::from(plotly, plot_ly, ggplotly, add_lines, add_markers, config, add_ribbons)
import::from(ggplot2, ggplot, geom_line, geom_density, geom_point, geom_ribbon, geom_segment, coord_cartesian, theme_minimal, theme, element_text, scale_color_manual, scale_fill_manual, scale_linetype_manual, aes, labs, guides, guide_legend)
import::from(dplyr, group_by, ungroup, do, select, mutate, filter, rename)
import::from(plyr, ldply, llply, compact)
import::from(magrittr, `%>%`, `%<>%`)
import::from(RColorBrewer, brewer.pal)
import::from(htmlwidgets, saveWidget)
import::from(latex2exp, TeX)

Agenda

Intro to splines

# Build ggplot layers for a piecewise continuous function.
#
# range:  vector of segment boundaries, including the points of discontinuity; for example
#         range = c(1, 2, 3) means the function is drawn over (1, 3) with a discontinuity at 2.
# dx:     step between consecutive x values within a segment.
# f:      list of functions, one per segment; f[[i]] is evaluated on the segment from
#         range[i] to range[i + 1].
# name:   value mapped to the color aesthetic of the layers.
# breaks: if TRUE, also draw points at the sampled x values that coincide with an element of range.
#
# Returns a list of ggplot layers that can be added to a plot with `+`.
add_func = function(range, dx, f, name=NA, breaks=TRUE) {
  # Fail early with a clear message instead of an obscure seq() or subscript error further down
  stopifnot(length(range) >= 2, length(f) >= length(range) - 1)

  # One iteration per segment; seq_len() counts segments directly (seq(1:n) miscounts when n < 1)
  data = seq_len(length(range) - 1) %>%
    ldply(function(idx) {
      x0 = range[idx]
      x1 = range[idx + 1]
      x = seq(x0, x1, by=dx)
      y = rep(f[[idx]](x), length.out=length(x)) # rep() covers a constant f that returns a single scalar
      # Terminate the segment with an NA at its right edge, presumably so that the drawn line
      # has a gap at the discontinuity instead of joining the next segment
      x = c(x, x1)
      y = c(y, NA)
      data.frame(x=x, y=y)
    })

  plot = list(
    geom_line(data=data, aes(x=x, y=y, color=name), show.legend=TRUE)
  )
  if (breaks) {
    # Mark the segment boundaries
    plot %<>% c(list(
      geom_point(data=data %>% filter(x %in% range), aes(x=x, y=y, color=name))
    ))
  }
  plot
}

# Create an empty ggplot carrying the manual color / fill / linetype scales, legend labels and
# theme used by the plots in this file.
#
# data: optional default data frame for the plot's layers.
#
# Layers pick a style by mapping color, fill or linetype to one of the keys of `styles`
# (for example color="true").
base_plot = function(data=NULL) {
  shades3 = rev(brewer.pal(5, "YlGnBu"))[2:4]
  shades5 = rev(brewer.pal(9, "YlGnBu"))[2:6]

  # One entry per series key: legend label plus whichever of color / fill / linetype it defines
  styles = list(
    "true" = list(label = "True", color = "#FF0000", linetype = "12"),
    "obs" = list(label = "Observed", color = "#C0C0C0"),
    "est" = list(label = "Estimated", color = "#000000", linetype = "solid"),
    "est1" = list(label = "Estimate #1", color = shades5[1], linetype = "solid"),
    "est2" = list(label = "Estimate #2", color = shades5[2], linetype = "solid"),
    "est3" = list(label = "Estimate #3", color = shades5[3], linetype = "solid"),
    "est4" = list(label = "Estimate #4", color = shades5[4], linetype = "solid"),
    "est5" = list(label = "Estimate #5", color = shades5[5], linetype = "solid"),
    "est-95-cl" = list(label = "Estimated (95% CL)", fill = shades3[1]),
    "est-dist" = list(label = "Estimated distribution", fill = shades3[2]),
    "y" = list(label = TeX("f(x)"), color = shades3[1], linetype = "solid"),
    "y1" = list(label = TeX("f'(x)"), color = shades3[2], linetype = "solid"),
    "y2" = list(label = TeX("f''(x)"), color = shades3[3], linetype = "solid"),
    "yhat" = list(label = TeX("\\hat{f}(x)"), color = shades3[1], linetype = "solid"),
    "yhat1" = list(label = TeX("\\hat{f}'(x)"), color = shades3[2], linetype = "solid"),
    "yhat2" = list(label = TeX("\\hat{f}''(x)"), color = shades3[3], linetype = "solid")
  )

  # Build one manual scale from the entries that define the aesthetic `aesName`;
  # entries without that aesthetic are left out of the scale.
  make_scale = function(scale, aesName) {
    defined = Filter(function(style) !is.null(style[[aesName]]), styles)
    values = lapply(defined, function(style) style[[aesName]])
    labels = lapply(defined, function(style) style$label)

    scale(
      name=NULL,
      breaks=names(values),
      values=unlist(values),
      labels=labels
    )
  }

  ggplot(data=data) +
    make_scale(scale_color_manual, "color") +
    make_scale(scale_fill_manual, "fill") +
    make_scale(scale_linetype_manual, "linetype") +
    theme_minimal() +
    theme(
      legend.text=element_text(size=12),
      legend.position="bottom"
    ) +
    labs(x=NULL, y=NULL)
}

# Return the plot unchanged; every plot in this file is passed through this single function.
show_plot = function(plot) plot

Splines: what are they?

A spline is a piecewise polynomial on $[a, b] \subset \mathbb{R}$. The points at which the polynomial segments are joined are called knots.

# Example spline on [0, 3] with knots at 1 and 2: a linear, a quadratic and a constant piece.
# The pieces meet at x = 1 (both equal 1) and jump at x = 2 (0.5 on the left, 0.25 on the right).
show_plot(
  base_plot() +
    add_func(
      c(0, 1, 2, 3), 0.001,
      list(
        function(x) x,
        function(x) 17/16 - (x - 1.25)^2,
        function(x) 0.25
      ),
      name="y"
    )
)

If each polynomial piece $P_i$ of a spline $S$ has degree ≤ m, then $S$ is said to have degree m and order m+1.

A spline is continuous in all derivatives everywhere except at its knots. At its knots a spline can have continuity order ranging from -1 (no continuity) through m-1.

Spline degree 1, continuity -1

# Two linear pieces that do not meet at the knot x = 1 (0.25 on the left, 0.5 on the right)
f = list(
  function(x) 1.25 - x,
  function(x) 1 - x / 2
)

show_plot(
  base_plot() +
    add_func(c(0, 1, 2), 0.001, f, name="y")
    # add_func(c(0, 1, 2), 0.001, f2, name="dy/dx") %>%
)

Spline degree 1, continuity 0

# Two linear pieces that meet at the knot x = 1 (both equal 0.5) ...
f = list(
  function(x) 1.5 - x,
  function(x) 1 - x / 2
)
# ... while the slope jumps from -1 to -1/2 there
df = list(
  function(x) -1,
  function(x) -1/2
)

show_plot(
  base_plot() +
    add_func(c(0, 1, 2), 0.001, f, name="y") +
    add_func(c(0, 1, 2), 0.001, df, name="y1", breaks=FALSE)
)

Spline degree 2, continuity -1

# Two quadratic pieces with a jump at the knot x = 1 (1 on the left, 1.5 on the right)
f = list(
  function(x) x^2,
  function(x) 2 - (x - 2)^2 / 2
)

show_plot(
  base_plot() +
    add_func(c(0, 1, 2), 0.001, f, name="y")
)

Spline degree 2, continuity 0

# Two quadratic pieces that meet at the knot x = 1 (both equal 1.5) ...
f = list(
  function(x) 0.5 + x^2,
  function(x) 2 - (x - 2)^2 / 2
)
# ... while the first derivative jumps from 2 to 1 there
df = list(
  function(x) x * 2,
  function(x) 2 - x
)

show_plot(
  base_plot() +
    add_func(c(0, 1, 2), 0.001, f, name="y") +
    add_func(c(0, 1, 2), 0.001, df, name="y1", breaks=FALSE)
)

Spline degree 2, continuity 1

# Two quadratic pieces whose value (1) and first derivative (2) agree at the knot x = 1 ...
f = list(
  function(x) x^2,
  function(x) 2 - (x - 2)^2
)
df = list(
  function(x) 2 * x,
  function(x) -2 * (x - 2)
)
# ... while the second derivative jumps from 2 to -2 there
d2f = list(
  function(x) 2,
  function(x) -2
)

show_plot(
  base_plot() +
    add_func(c(0, 1, 2), 0.001, f, name="y") +
    add_func(c(0, 1, 2), 0.001, df, name="y1", breaks=FALSE) +
    add_func(c(0, 1, 2), 0.001, d2f, name="y2", breaks=FALSE)
)

Spline modeling

Spline modeling

Goal: fit a spline to observations. Cubic splines are commonly used.

In a simple spline model, quality of fit is measured by the mean squared error (MSE).

# Map a derivative order to its series key: 0 -> "yhat", any other d -> "yhat<d>".
# Vectorize() lets the function be applied element-wise to a vector of orders.
deriv_name = Vectorize(function(d) {
  if (d != 0) {
    return(sprintf("yhat%d", d))
  }
  "yhat"
})

# Fit a spline smooth to the observations with gam() and return ggplot layers showing the
# observations, the fitted curve (optionally with numerical derivatives) and the knots.
#
# x, y:       observed coordinates.
# k:          passed to s() as `k`.
# m.fit:      first element of the `m` argument of s() (basis order in mgcv's P-spline bases).
# m.penalty:  second element of the `m` argument of s() (penalty order).
# cyclic:     if TRUE the cyclic basis bs="cp" is used, otherwise bs="ps".
# show.knots: if TRUE (the default), mark the fitted value at each knot.
# deriv:      highest derivative order to draw in addition to the fit itself; derivatives are
#             approximated by finite differences of the fitted values.
#
# Returns a list of ggplot layers that can be added to a plot with `+`.
add_psplines = function (x, y, k, m.fit, m.penalty, cyclic=FALSE, show.knots=TRUE, deriv=0) {
  data.obs = data.frame(x=x, y=y)
  bs.type = if (cyclic) "cp" else "ps"
  model = gam(y ~ s(x, k=k, bs=bs.type, m=c(m.fit, m.penalty)), family=gaussian, data=data.obs)

  # Fitted values at the observed x; column d holds the derivative order of each row (0 = the fit)
  data.pred = data.obs %>%
    select(x) %>%
    mutate(y = predict(model, data.frame(x=x)), d=0)

  if (deriv > 0) {
    # Each pass differentiates the rows of the previous order and prepends the result
    for (idx in seq_len(deriv)) {
      data.pred %<>%
        filter(d == idx - 1) %>%
        do((function(df) {
          dy = diff(df$y) / diff(df$x)
          # diff() yields one value fewer than there are points; average neighbouring slopes to get
          # one value per point, wrapping around at the ends when cyclic and repeating the end
          # slope otherwise
          if (cyclic) {
            dy = (c(tail(dy, 1), dy) + c(dy, head(dy, 1))) / 2
          } else {
            dy = (c(head(dy, 1), dy) + c(dy, tail(dy, 1))) / 2
          }
          data.frame(x=df$x, y=dy, d=idx)
        })(.)) %>%
        rbind(data.pred)
    }
  }

  layers = list(
    geom_point(data=data.obs, aes(x=x, y=y, color="obs")),
    geom_line(data=data.pred, aes(x=x, y=y, group=d, color=deriv_name(d)))
  )

  if (show.knots) {
    knots = model$smooth[[1]]$knots
    if (!cyclic) {
      # Drop m.fit + 1 knots from each end; presumably these are the knots outside the data
      # range that the "ps" basis adds (TODO confirm against the mgcv documentation)
      knots = knots[(m.fit+2):(length(knots)-m.fit-1)]
    }

    # Place each knot marker on the fitted curve
    knots = data.frame(x=knots)
    knots = cbind(data.frame(y=predict(model, knots)), knots)

    layers = c(layers, list(
      geom_point(data=knots, aes(x=x, y=y, color=deriv_name(0)))
    ))
  }

  layers
}

Linear spline model, 3 knots

# Noisy observations of a parabola on [0, 1], fitted with k = 3 and m = c(0, 0)
set.seed(0)
f = function(x) (1 - (2 * x - 1/2)^2) / 2
x = seq(0, 1, by=0.01)
y = rnorm(length(x), mean=f(x), sd=0.025)

show_plot(base_plot() + add_psplines(x, y, 3, 0, 0))

Linear spline model, 10 knots

# Same observations, fitted with k = 10 and m = c(0, 0)
show_plot(base_plot() + add_psplines(x, y, 10, 0, 0))

Linear spline model, 20 knots

# Same observations, fitted with k = 20 and m = c(0, 0)
show_plot(base_plot() + add_psplines(x, y, 20, 0, 0))

Overfitting and penalties

Spline fitting with MSE is sensitive to the number of knots. Too many knots lead to overfitting.

One solution to this problem is to include a smoothness term in the measure of fit; this term is known as a "penalty".

The most common penalty is the integral of the squared second derivative, also known as a 2nd-order penalty.

Splines used for fitting with a penalty term are known as penalized splines, or P-splines.

Linear P-spline model, 20 knots

# Same observations, fitted with k = 20 and m = c(0, 2), i.e. with a 2nd-order penalty
show_plot(base_plot() + add_psplines(x, y, 20, 0, 2))

Spline modeling: cyclic time

Splines aren't smooth at the “ends” of a time cycle.

# Noisy observations of cos(x) over one full period, fitted with the non-cyclic basis
# (k = 7, m = c(2, 2)); the first derivative of the fit is drawn as well
set.seed(0)
f = function(x) cos(x)
x = seq(0, 2 * pi, by=0.01)
y = rnorm(length(x), mean=f(x), sd=0.025)

show_plot(base_plot() + add_psplines(x, y, 7, 2, 2, deriv=1))

Spline modeling: cyclic time

That is what cyclic splines are used for.

# Same observations, fitted with the cyclic basis (k = 4, m = c(2, 2)) and its first derivative
show_plot(base_plot() + add_psplines(x, y, 4, 2, 2, cyclic=TRUE, deriv=1))

Outbreak outcome estimation with P-splines

1. Define outcome measure

Onset = time corresponding to 2.5% of total cases

# Simulated "true" epidemic curve: a single cosine-shaped outbreak between `start` and `end`
set.seed(0)
start = 20
end = 40
peak = 300

t.obs = seq(1, 52)
dt = 0.01

# Fine time grid on which the true curve is evaluated
# NOTE(review): the 0.52 offset may have been meant as 0.5 - confirm
t = seq(min(t.obs) - 0.5, max(t.obs) + 0.52 - dt, by=dt)

# Raised cosine scaled to `peak`, rounded to whole cases and zeroed outside [start, end]
phase = (t - start) / (end - start) * 2 * pi
cases0 = round((1 - cos(phase)) / 2 * peak) * (t >= start) * (t <= end)
data0 = data.frame(time=t, cases=cases0)

# Smooth the rounded curve with a Poisson spline model; its prediction serves as the truth
m0 = gam(cases ~ s(x=time, k=20, bs="ps", m=3), family=poisson, data=data0)
cases.true = predict(m0, data.frame(time=t), type="response")
data.true = data.frame(time=t, cases=cases.true)

threshold = 0.025

# True onset: the last grid point at which the cumulative share of cases is still <= threshold
cum.frac.true = cumsum(cases.true) / sum(cases.true)
before.onset = cum.frac.true <= threshold
t.onset.true = t[before.onset]
cases.onset.true = cases.true[before.onset]
data.onset.true = tail(data.frame(time=t.onset.true, cases=cases.onset.true), 1)

show_plot(
  base_plot() +
    geom_line(data=data.true, aes(x=time, y=cases, color="true"), size=1) +
    geom_segment(data=data.onset.true, aes(x=time, xend=time, y=0, yend=cases, color="true"), size=1) +
    labs(x="Week", y="Cases")
)

2. Fit a P-spline model

# Observations: Poisson draws around the true curve evaluated at the observation times
cases.true0 = predict(m0, data.frame(time=t.obs), type="response")
cases.obs = rpois(n=length(cases.true0), lambda=cases.true0)
data.obs = data.frame(time=t.obs, cases=cases.obs)

# Fit a Poisson model with a cyclic spline basis to the observations, then predict on the fine grid
m = gam(cases ~ s(x=time, k=8, bs="cp", m=3), family=poisson, data=data.obs)
cases.pred = predict(m, data.frame(time=t), type="response")
data.pred = data.frame(time=t, cases=cases.pred)

show_plot(
  base_plot() +
    geom_line(data=data.true, aes(x=time, y=cases, color="true"), size=1) +
    geom_point(data=data.obs, aes(x=time, y=cases, color="obs")) +
    geom_line(data=data.pred, aes(x=time, y=cases, color="est"), size=0.5) +
    labs(x="Week", y="Cases")
)

3. Sample model parameters

library(pspline.inference)

# Series key of the n-th sampled curve: 1 -> "est1", 2 -> "est2", ...
sample_name = function(n) sprintf("est%d", n)

# Draw n sampled case curves from the fitted model on the fine time grid and plot them
n=5
cases.samples = pspline.sample.timeseries(m, data.frame(time=t), pspline.outbreak.cases, samples=n)
show_plot(
  base_plot(cases.samples) +
    geom_line(data=data.true, aes(x=time, y=cases, color="true"), size=1) +
    geom_point(data=data.obs, aes(x=time, y=cases, color="obs")) +
    geom_line(aes(x=time, y=cases, group=pspline.sample, color=sample_name(pspline.sample)), size=0.5) +
    labs(x="Week", y="Cases")
)

3. Sample model parameters

# Zoom window reused by the following plots. The limits are passed by their full argument names
# (xlim / ylim) rather than relying on partial matching of x= / y=.
zoom1 = coord_cartesian(xlim=c(20, 23.5), ylim=c(0, 80))

# Sampled case curves again, restricted to the zoom window
(base_plot(cases.samples) +
  geom_line(data=data.true, aes(x=time, y=cases, color="true"), size=1) + 
  geom_point(data=data.obs, aes(x=time, y=cases, color="obs")) +
  geom_line(aes(x=time, y=cases, group=pspline.sample, color=sample_name(pspline.sample)), size=0.5) + 
  labs(x="Week", y="Cases") +
  zoom1
) %>% show_plot()

4. Sample outcome measure

# Onset of one sampled curve: the last row at which the cumulative share of cases is still
# <= threshold. Returns a one-row data frame with the sample id, the onset time and the
# cases at that time.
onset_of_sample = function(data) {
  below = cumsum(data$cases) / sum(data$cases) <= threshold
  data.frame(
    pspline.sample = tail(data$pspline.sample, 1),
    onset = tail(data$time[below], 1),
    cases = tail(data$cases[below], 1)
  )
}

# One onset row per sampled curve
onset.samples = cases.samples %>%
  group_by(pspline.sample) %>%
  do(onset_of_sample(.)) %>%
  ungroup()

# Unzoomed variant of the plot below, kept for reference:
#(base_plot(cases.samples) +
#  geom_line(data=data.true, aes(x=time, y=cases, color="true"), size=1) + 
#  geom_point(data=data.obs, aes(x=time, y=cases, color="obs")) +
#  geom_line(aes(x=time, y=cases, group=pspline.sample, color=sample_name(pspline.sample)), size=0.5) +
#  geom_segment(data=onset.samples, aes(x=onset, xend=onset, y=0, yend=cases, group=pspline.sample, color=sample_name(pspline.sample)), size=0.5) + 
#   labs(x="Week", y="Cases")
#) %>% show_plot()

# Sampled curves with a vertical segment at each sample's onset, restricted to the zoom window
show_plot(
  base_plot(cases.samples) +
    geom_line(data=data.true, aes(x=time, y=cases, color="true"), size=1) +
    geom_point(data=data.obs, aes(x=time, y=cases, color="obs")) +
    geom_line(aes(x=time, y=cases, group=pspline.sample, color=sample_name(pspline.sample)), size=0.5) +
    geom_segment(data=onset.samples, aes(x=onset, xend=onset, y=0, yend=cases, group=pspline.sample, color=sample_name(pspline.sample)), size=0.5) +
    labs(x="Week", y="Cases") +
    zoom1
)

5. Infer outcome distribution

# Draw a large number of samples to estimate the onset distribution and the band of the case curve
n.2k = 2000
# The onset threshold is taken from `threshold` (the value used for the true onset) rather than
# repeating the literal, so the true marker and the estimated distribution cannot drift apart
onset.samples.2k = pspline.sample.scalars(m, data.frame(time=t), pspline.outbreak.thresholds(onset=threshold), samples=n.2k)
cases.est.2k = pspline.estimate.timeseries(m, data.frame(time=t), pspline.outbreak.cases, samples=n.2k)
(base_plot(cases.samples) +
  geom_ribbon(data=cases.est.2k, aes(x=time, ymin=cases.lower, ymax=cases.upper, fill="est-95-cl"), color=NA) +
  # The density is multiplied by 5 so that it is visible on the cases axis
  geom_density(data=onset.samples.2k, aes(x=onset, y=..density..*5, fill="est-dist"), color=NA, trim=TRUE) +
  geom_line(data=data.true, aes(x=time, y=cases, color="true"), size=1) + 
  geom_segment(data=data.onset.true, aes(x=time, xend=time, y=0, yend=cases, color="true"), size=1) + 
  geom_point(data=data.obs, aes(x=time, y=cases, color="obs")) +
  labs(x="Week", y="Cases") +
  # Only the fill legend is shown
  guides(fill=guide_legend(), color="none", linetype="none") +
  zoom1
) %>% show_plot()

pspline.inference

pspline.estimate.scalars

# Usage sketch (`...` are placeholders): fit a model, define a function of the data that
# computes the outcome, then pass both to pspline.estimate.scalars
model <- gam(..., data)
calc.outcome <- function(data) { ... }
outcome.estimate <- pspline.estimate.scalars(
  model, data, calc.outcome
)

pspline.validate.scalars

# Usage sketch (`...` are placeholders): three generator functions, each taking the value
# named by its parameter - none, a truth, and a set of observations
gen.truth <- function() { ... }
gen.observations <- function(truth) { ... }
gen.model <- function(observations) { ... }

pspline.validate.scalars

# Usage sketch: the generators, the counts n.truth and n.observations, and the outcome
# function are passed to pspline.validate.scalars
validation <- pspline.validate.scalars(
  gen.truth, n.truth, 
  gen.observations, n.observations, 
  gen.model, calc.outcome
)

RSV

Motivation

RSV is one of the top infectious causes of infant hospitalizations and mortality in developed countries.

Effective prophylaxis (palivizumab) is expensive, and therefore only recommended during RSV season.

Prophylaxis guidelines (by the American Academy of Pediatrics) use the same schedule for all states except for Florida.

Is that schedule a good match for the actual RSV season timing in CT?

Motivation

# Embed the coverage-by-region figure as an <img> tag with a fixed width.
# NOTE(review): src points at a PDF file; not every browser renders a PDF inside <img> - confirm
# that the figure displays as intended.
htmltools::tags$img(
  src="coverageByRegionAAPInsight-Standalone.pdf",
  style="width: 768px;"
)

Questions

How many RSV cases occur during the AAP-recommended prophylaxis protection window?

How much can we increase that by adjusting the window based on state-level and county-level data?

Approach

If prophylaxis lasts from $T_\text{start}$ to $T_\text{end}$, protected fraction is defined as:

$$ \begin{aligned} \text{protected } & \text{fraction}(T_\text{start}, T_\text{end}) = \\ &\frac{\text{RSV cases occurring between } T_\text{start} \text{ and } T_\text{end}}{\text{all RSV cases}} \end{aligned} $$

Using pspline.inference:

  1. Estimate protected fraction for AAP prophylaxis guidelines.
  2. Estimate protected fraction for an alternate prophylaxis schedule --- same duration, but lined up with RSV season and aligned to useful calendar intervals.

Results

Protected fraction of AAP-recommended prophylaxis is 94.08% statewide (95% CI: 93.70 -- 94.42%)

Aligning to RSV season by county or statewide and rounding prophylaxis window to 1 or 2 weeks increases protected fraction by ~0.75-2%.

Least increase: 1-week rounding, season alignment by county = 94.81% statewide protected fraction (95% CI: 94.47 -- 95.12%).

Rounding schedule to 4 weeks gives no benefit over AAP. (AAP is as good as it gets for month-based schedule in CT.)

Adjusting for year-to-year variation in RSV season timing doesn't add any benefit.

Conclusions

Potential ~1% gain in protection by adjusting prophylaxis schedule to match local (state or county) season timing (rounded to 1- or 2-week intervals).

Need cost-benefit analysis to weigh 1% improvement against implementation complexity.

Analysis method generalizes to other states.

pspline.inference generalizes to other outcomes.



airbornemint/outbreak-inference documentation built on Jan. 28, 2021, 6:38 a.m.