# tests/bsplines2.R

# Print numbers with three significant digits.
options(digits = 3)
# Route all subsequent plots to numbered PNG files (one file per page).
png(filename = "bsplines2-%02d.png", res = 100)
library(tidyr)
library(dplyr)
library(ggplot2)
library(pomp)

# Integrate a four-state vector field over [0, 2] and plot the result.
# Each derivative is a dot product of five coefficients with a spline basis:
#   Dx, Dz: bases evaluated inside the C snippet (ordinary and periodic);
#   Dy, Dw: bases read from the covariate table (starting at trend_1, seas_1).
# NOTE(review): the snippet reads five coefficients through a pointer to a1,
# which presumably relies on a1..a5 being adjacent in `params`; only a1..a3
# are listed in `paramnames`.
trajectory(
  t0 = 0,
  times = seq(from = 0, to = 2, by = 0.02),
  skeleton = vectorfield(Csnippet(r"{
    const double *a = &a1;
    double knots[] = {-3, -2, -1, 0, 1, 2, 3, 4, 5};
    double s[5], p[5];
    bspline_basis_eval(t,knots,3,5,s);
    periodic_bspline_basis_eval(t,1,3,5,p);
    Dx = dot_product(5,a,s);
    Dy = dot_product(5,a,&trend_1);
    Dz = dot_product(5,a,p);
    Dw = dot_product(5,a,&seas_1);
  }")),
  params = c(
    a1 = -1, a2 = 1, a3 = 0, a4 = -3, a5 = 18,
    x_0 = 0, y_0 = 0, z_0 = 0, w_0 = 0
  ),
  paramnames = c("a1", "a2", "a3"),
  statenames = c("x", "y", "w", "z"),
  covar = covariate_table(
    times = seq(from = 0, to = 2.1, by = 0.001),
    trend = bspline_basis(x = times, nbasis = 5, degree = 3, rg = c(0, 2)),
    seas = periodic_bspline_basis(
      x = times, period = 1, nbasis = 5, degree = 3
    )
  )
) |>
  plot()

# Same setup as the first trajectory, but every basis is replaced by its first
# derivative (the *_eval_deriv calls in the snippet, deriv = 1 in the covariate
# table). The trajectory is converted to a long data frame with columns
# time, name, value and stored in `dat`.
dat <- trajectory(
  t0 = 0,
  times = seq(from = 0, to = 2, by = 0.02),
  skeleton = vectorfield(Csnippet(r"{
    const double *a = &a1;
    double knots[] = {-3, -2, -1, 0, 1, 2, 3, 4, 5};
    double s[5], p[5];
    bspline_basis_eval_deriv(t,knots,3,5,1,s);
    periodic_bspline_basis_eval_deriv(t,1,3,5,1,p);
    Dx = dot_product(5,a,s);
    Dy = dot_product(5,a,&trend_1);
    Dz = dot_product(5,a,p);
    Dw = dot_product(5,a,&seas_1);
  }")),
  params = c(
    a1 = -1, a2 = 1, a3 = 0, a4 = -3, a5 = 18,
    x_0 = 0, y_0 = 0, z_0 = 0, w_0 = 0
  ),
  paramnames = c("a1", "a2", "a3"),
  statenames = c("x", "y", "w", "z"),
  covar = covariate_table(
    times = seq(from = 0, to = 2.1, by = 0.001),
    trend = bspline_basis(
      x = times, nbasis = 5, degree = 3, deriv = 1, rg = c(0, 2)
    ),
    seas = periodic_bspline_basis(
      x = times, period = 1, deriv = 1, nbasis = 5, degree = 3
    )
  )
) |>
  as.data.frame() |>
  select(time, x, y, z, w) |>
  pivot_longer(cols = -time)

# Covariate table holding the undifferentiated (deriv = 0) ordinary and
# periodic B-spline bases on a fine time grid; stored in `x`.
x <- covariate_table(
  times = seq(from = 0, to = 2.1, by = 0.001),
  trend = bspline_basis(
    x = times, nbasis = 5, degree = 3, deriv = 0, rg = c(0, 2)
  ),
  seas = periodic_bspline_basis(
    x = times, period = 1, deriv = 0, nbasis = 5, degree = 3
  )
)

# Turn the covariate table in `x` into a long data frame of spline curves:
#   1. melt the basis table and attach the matching time to each entry;
#   2. split each covariate name into a basis family (`name`) and an integer
#      basis index (`fn`);
#   3. weight every basis value by the coefficient with that index and sum
#      over the basis functions for each (name, time);
#   4. subtract each curve's first value so that it starts at zero.
# The result (columns name, time, value) overwrites `x`.
x <- melt(x@table) |>
  select(name = Var1, value, id = Var2) |>
  left_join(
    rename(melt(x@times), time = value),
    by = c(id = "name")
  ) |>
  select(-id) |>
  separate(col = name, into = c("name", "fn")) |>
  mutate(
    fn = as.integer(fn),
    value = c(a1 = -1, a2 = 1, a3 = 0, a4 = -3, a5 = 18)[fn] * value
  ) |>
  group_by(name, time) |>
  summarise(value = sum(value)) |>
  ungroup() |>
  arrange(name, time) |>
  group_by(name) |>
  mutate(value = value - value[1]) |>
  ungroup()

# Stack the trajectory states (`dat`) with the directly computed spline curves
# (`x`), keep times up to 2, and draw one free-scaled panel per series in a
# fixed order, two panels per row.
dat |>
  bind_rows(x) |>
  filter(time <= 2) |>
  mutate(
    name = factor(name, levels = c("trend", "seas", "x", "w", "y", "z"))
  ) |>
  ggplot(mapping = aes(x = time, y = value)) +
  geom_line() +
  facet_wrap(~name, ncol = 2, scales = "free_y") +
  theme_bw()

# Close the PNG graphics device opened at the top of the script.
dev.off()
# kingaa/pomp documentation built on April 24, 2024, 11:25 a.m.