R/CGillespie.R

#' R6 class implementing the Gillespie method using cpp11
#'
#' This is a subclass of Simulator, implementing the Gillespie method
#' to simulate a compartmental model using cpp11. This class is only available
#' if the `cpp11` package is installed.
#' 
#' @docType class
#' @examples
#' # an SIR model
#' SIR = Compartmental$new(S, I, R, title="SIR")
#' SIR$transition(S->I ~ beta*S*I/N, N=S+I+R, name="infection")
#' SIR$transition(I->R ~ gamma*I, name="recovery")
#' g = CGillespie$new(SIR)
#' g$simulate(0:100, y0=c(S=1000, I=10, R=0), parms=c(beta=0.4,gamma=0.2))
#' @export
CGillespie = R6Class(
  "CGillespie",
  inherit = Simulator,
  private = list(
    # the C++ program
    .program = "",
    # the dynamic library generated by cpp11 (the DLLInfo returned by
    # cpp11::cpp_source)
    lib = NULL,
    # the R wrapper of the C++ gillespie function generated by cpp11
    gillespie = NULL,

    # the header of the C++ program.
    # R_ext/Random.h declares GetRNGstate()/PutRNGstate(), which must bracket
    # every use of R's RNG (unif_rand, exp_rand): unlike Rcpp, cpp11 does not
    # save/restore the RNG state around registered functions.
    header = "#include \"cpp11.hpp\"
#include \"Rmath.h\"
#include <R_ext/Random.h>
",
    # the main part of the gillespie function. It repeatedly calls step() to
    # simulate a single event, and records for each requested time t[i] the
    # state just before the first event occurring at or after t[i].
    main = "// RAII guard: loads R's RNG state on construction and writes it back on
// destruction (also when an error is thrown), so that set.seed() is honoured.
struct gillespie_local_rng {
  gillespie_local_rng() { GetRNGstate(); }
  ~gillespie_local_rng() { PutRNGstate(); }
};

[[cpp11::register]]
cpp11::doubles_matrix<> gillespie(cpp11::doubles t, cpp11::doubles y0, cpp11::doubles parms) {
  cpp11::writable::doubles y(y0);
  R_xlen_t n_t = t.size();
  R_xlen_t n_y = y.size();
  cpp11::writable::doubles_matrix<> data(n_t, n_y + 1);
  if (!y.named())
    cpp11::stop(\"the initial values must be named\");
  // nothing to record; also avoids reading t[0] out of bounds
  if (n_t == 0)
    return data;
  gillespie_local_rng rng;
  double time = t[0];
  R_xlen_t i = 0;
  while (i < n_t) {
    auto l = step(time, y, parms);
    time = l[0];
    while (i < n_t && time >= t[i]) {
      data(i, 0) = t[i];
      for (R_xlen_t j = 1; j <= n_y; ++j) {
        // copy through a double: proxy-to-proxy assignment does not compile
        double v = y[j-1];
        data(i, j) = v;
      }
      ++i;
    }
    for (R_xlen_t j = 0; j < n_y; ++j)
      y[j] = l[j+1];
  }
  return data;
}",
    # the C++ code for copying the current states to the results
    result = list(
      "cpp11::writable::doubles __result(__y.size() + 1);",
      "for (size_t i = 0; i < __y.size(); ++ i)",
      "  __result[i+1] = __y[i];"
    ),
    # the C++ code for finding the next event and its event time.
    # A zero or NaN total rate means no further events (event time +Inf).
    # A negative or infinite total rate would make time run backwards (an
    # endless loop) or overrun __rate, so it is reported as an error.
    middle = list(
      "if (__total == 0 || std::isnan(__total)) {",
      "  __result[0] = R_PosInf;",
      "  return __result;",
      "}",
      "if (__total < 0 || std::isinf(__total))",
      "  cpp11::stop(\"the total transition rate must be finite and nonnegative\");",
      "__result[0] = t + exp_rand()/__total;",
      "double __p = unif_rand() * __total;",
      "size_t __transition = 0;",
      "for (; __rate[__transition] <= __p; ++ __transition);"
    ),

    # remove the bracket from a bracketed expression
    remove.bracket = function(x) {
      if (is.call(x) && x[[1]] == "(")
        private$remove.bracket(x[[2]]) else x
    },

    # convert an R expression to a C++ expression (returned as a string).
    # Arithmetic operators map to their C++ counterparts, `^` maps to pow(),
    # `[[` is treated like `[`, an access with more than one index x[i, j]
    # becomes x(i, j), and any other call f(a, b) becomes f(a, b) with each
    # argument converted recursively. Non-calls are converted by as.character.
    cpp = function(C) {
      if (is.call(C)) {
        if (C[[1]] == "[[") C[[1]] = "["
        switch(
          as.character(C[[1]]),
          `+` = if (length(C) == 2) private$cpp(C[[2]]) else
            paste0(private$cpp(C[[2]]), " + ", private$cpp(C[[3]])),
          # the operand of a unary minus must be converted too; the space is
          # kept so that a nested unary minus never becomes C++'s `--`
          `-` = if (length(C) == 2) {
            paste("-", private$cpp(C[[2]]))
          } else paste0(private$cpp(C[[2]]), " - ", private$cpp(C[[3]])),
          `*` = paste0(private$cpp(C[[2]]), " * ", private$cpp(C[[3]])),
          `/` = paste0(private$cpp(C[[2]]), " / ", private$cpp(C[[3]])),
          `^` = paste0("pow(",
                       private$cpp(private$remove.bracket(C[[2]])),
                       ", ",
                       private$cpp(private$remove.bracket(C[[3]])),
                       ")"),
          `(` = paste0("(", private$cpp(C[[2]]), ")"),
          `[` = if (length(C) > 3) private$cpp(C[-1]) else
            paste0(private$cpp(C[[2]]), "[", private$cpp(C[[3]]), "]"),
          paste0(private$cpp(C[[1]]), "(",
                 paste(unlist(lapply(C[-1], private$cpp)), collapse = ", "),
                 ")")
        )
      } else as.character(C)
    },

    # generate a C++ assignment statement, optionally with a declaration type
    assign = function(var, value, type = "") {
      s = paste0(var, " = ", value, ";")
      if (type != "") paste(type, s) else s
    },

    # generate a C++ array access (by index)
    array = function(var, index) {
      paste0(var, "[", index, "]")
    },

    # generate a C++ statement for an array declaration
    declare_array = function(var, length, type="double") {
      paste0(type, " ", var, "[", length, "];")
    },

    # format the transition rate calculation as C++ statements.
    # __rate[k] holds the cumulative rate of transitions 0..k, and __total
    # the sum of all rates.
    format.rate = function(T) {
      if (length(T) == 0)
        stop("the model must have at least one transition", call. = FALSE)
      l = list(private$declare_array("__rate", length(T)))
      for (i in seq_along(T)) {
        rate = if (i > 1) {
          call("+", private$array("__rate", i-2), T[[i]]$rate)
        } else T[[i]]$rate
        l = c(l, private$assign(private$array("__rate", i-1), private$cpp(rate)))
      }
      c(l, private$assign("__total", private$array("__rate", length(T)-1), "double"))
    },

    # generate a C++ block
    block = function(statements, indent="") {
      paste0(indent, "{", "\n",
             paste(paste(indent, "  ", statements), collapse="\n"),
             "\n", indent, "}")
    },

    # format a calculation (lhs ~ rhs) as a C++ declaration statement.
    # A derivative lhs (a call to `'`) is stored in a variable named _d_<var>.
    format.equation = function(eq) {
      var = eq[[2]]
      if (is.call(var)) {
        if (var[[1]] != "'")
          stop("Invalid equation ", paste(deparse(eq), collapse = " "))
        var = as.name(paste0("_d_", var[[2]]))
      }
      value = private$cpp(eq[[3]])
      private$assign(var, value, "double")
    },

    # format the access to a named vector as C++ statements, one local
    # double per element of S, read from the array __<name>
    format.var = function(S, name) {
      lapply(seq_along(S), function(i) {
        private$assign(S[[i]], private$array(paste0("__", name), i-1), "double")
      })
    },

    # format a C++ switch statement. The case labels are names(values), or
    # the 0-based position of a value when it has no name.
    format.switch = function(var, values, indent) {
      labels = names(values)
      l = list()
      for (i in seq_along(values)) {
        n = labels[i]
        if (is.null(n) || is.na(n) || n == "") n = i - 1
        l = c(l, paste0("case ", n, ":"),
              paste0("  ", values[[i]]), "  break;")
      }
      paste0("switch(", var, ") ", private$block(l, indent))
    },

    # format state transitions as a C++ switch statement: the case for each
    # transition decrements its source compartment and increments its
    # destination. __result[k] holds the k-th compartment (__result[0] is
    # the event time).
    format.transition = function(transitions) {
      l = lapply(transitions, function(tr) {
        l = list()
        if (!is.null(tr$from))
          l = c(l, private$assign(
            private$array(
              "__result",
              which(private$compartments == tr$from)
            ),
            paste0(tr$from, " - 1")
          ))
        if (!is.null(tr$to))
          l = c(l, private$assign(
            private$array(
              "__result",
              which(private$compartments == tr$to)
            ),
            paste0(tr$to, " + 1")
          ))
        l
      })
      private$format.switch("__transition", unname(l), indent = "  ")
    },

    # build the C++ program for simulating a single event
    build = function(model) {
      l = c(
        private$format.substitution(),
        private$result,
        private$format.rate(model$transitions),
        private$middle,
        private$format.transition(model$transitions),
        "return(__result);"
      )
      paste0("inline cpp11::doubles step(double t, cpp11::doubles __y, cpp11::doubles __parms) ",
             private$block(l, ""))
    },

    # reorder a named vector by the given names, dropping the original names
    # first so that the result is a plain double vector named by `order`
    order = function(x, order) {
      xx = as.numeric(x[order])
      names(xx) = order
      xx
    },

    # perform the simulation.
    run = function(t, y, parms) {
      # counts must be finite nonnegative whole numbers; round() is used
      # because as.integer() turns values beyond the integer range into NA,
      # which which() would silently ignore
      mistype = which(!is.finite(y) | y != round(y) | y < 0)
      if (length(mistype) != 0) {
        if (length(mistype) > 1) {
          s = "s"
          a = ""
        } else {
          s = ""
          a = "a "
        }
        stop("the initial condition", s, " for ",
             paste(names(y[mistype]), collapse=", "),
             " must be ", a, "nonnegative integer", s, ".")
      }
      y = private$order(y, private$compartments)
      parms = private$order(parms, private$parameters)
      data = as.data.frame(private$gillespie(as.numeric(t), y, parms))
      colnames(data) = c("time", names(y))
      data
    }
  ),

  public = list(
    #' @description constructor
    #' @param model an object of the `Compartmental` class
    #' @details the CGillespie class is only available if the cpp11
    #' package is installed. The C++ simulation code is compiled when the
    #' object is constructed.
    #' @examples
    #' # an SIR model
    #' library(cpp11) # cpp11 is required to use CGillespie
    #' SIR = Compartmental$new(S, I, R, N=S+I+R, title="SIR")
    #' SIR$transition(S->I ~ beta*S*I/N, name="infection")
    #' SIR$transition(I->R ~ gamma*I, name="recovery")
    #' g = CGillespie$new(SIR)
    #' g$simulate(0:100, y0=c(S=1000, I=10, R=0), parms=c(beta=0.4,gamma=0.2))
    initialize = function(model) {
      # requireNamespace() checks availability without attaching cpp11 to
      # the user's search path
      if (!requireNamespace("cpp11", quietly = TRUE)) {
        stop("cpp11 is required to use CGillespie. Either install the package, or use RGillespie instead", 
             call. = FALSE)
      }
      super$initialize(model)
      private$.program = paste(
        private$header,
        self$model,
        private$main,
        sep = "\n"
      )
      # the generated R wrapper is defined in a dedicated environment and
      # fetched from there explicitly
      env = new.env()
      private$lib = cpp11::cpp_source(code = private$.program, env = env)
      private$gillespie = get("gillespie", envir = env, inherits = FALSE)
    },

    #' @description the destructor
    #' @details The destructor unloads the dynamic library loaded by cpp11
    #' after compiling the C++ simulation code, if it is still loaded.
    finalize = function() {
      # nothing to unload if compilation never completed
      if (is.null(private$lib)) return(invisible(NULL))
      path = private$lib[["path"]]
      # a clone shares the same library, which may already be unloaded
      loaded = vapply(getLoadedDLLs(), function(dll) dll[["path"]], character(1))
      if (path %in% loaded) dyn.unload(path)
      invisible(NULL)
    }
  )
)
junlingm/REpiSim documentation built on Nov. 28, 2023, 2:35 a.m.