R/apollo_cppScript.R

Defines functions apollo_cppScript

Documented in apollo_cppScript

#' Generate C++ function to calculate V
#' 
#' Writes the source code of a C++ function (\code{utilPre}) that evaluates every
#' utility in \code{V} for each observation and each draw, compiles it with
#' \code{Rcpp::cppFunction}, and returns the resulting R function.
#' 
#' V must be a list of functions.
#' 
#' @param apollo_probabilities Likelihood function of the whole model.
#' @param apollo_beta Named numeric vector of parameters to be estimated.
#' @param apollo_inputs List of arguments and settings generated by \link{apollo_validateInputs}.
#' @param V Named list of functions.
#' 
#' @return A function with the C++ signature
#'         \code{List utilPre(DataFrame db, NumericVector b, List drw, List L)}:
#'         \itemize{
#'           \item \code{db}: Data.frame containing observations.
#'           \item \code{b}: Named numeric vector containing parameters to be estimated.
#'           \item \code{drw}: Named list containing draws, as contained inside \code{apollo_inputs}.
#'           \item \code{L}: List. Part of the generated signature, but not read by the generated code.
#'         }
#'         It returns a named list with one numeric matrix (observations by draws) per element of \code{V}.
#'         \code{NULL} is returned instead if \code{apollo_varList} returns \code{NULL} or if
#'         the generated C++ code fails to compile.
#'         
#' @importFrom Rcpp cppFunction
#' @export
apollo_cppScript <- function(apollo_probabilities, apollo_beta, apollo_inputs, V){
  # Check that all elements of V and randCoeff are functions
  if(!all(vapply(V, is.function, logical(1)))) stop("All utilities should be functions for them to be converted to C++")
  if(!all(vapply(apollo_inputs$randCoeff, is.function, logical(1)))) stop("All definitions of random coefficients should be functions for them to be converted to C++")
  
  # Fetch the variable lists; a NULL result means no C++ function is generated
  vars <- apollo_varList(apollo_probabilities, apollo_beta, apollo_inputs, V, cpp=TRUE)
  if(is.null(vars)) return(NULL)
  if(length(V) == 0) stop("V should contain at least one utility for it to be converted to C++")
  
  # Dimensions of the utility matrices: observations by draws
  nObs  <- nrow(apollo_inputs$database)
  draws <- apollo_inputs$apollo_draws
  nDrw  <- if(anyNA(draws)) 1 else draws$interNDraws
  
  # vars$r and vars$p are two-column tables (name, C++ expression).
  # They are only used when present and non-empty.
  hasRows <- function(m) !is.null(m) && NROW(m) > 0
  useR   <- hasRows(vars$r)
  useP   <- hasRows(vars$p)
  rNames <- if(useR) vars$r[,1] else character(0)
  pNames <- if(useP) vars$p[,1] else character(0)
  vNames <- names(V)
  
  # Build variable definitions. A failed `if` yields NULL, which c() drops,
  # so empty groups contribute no declaration lines.
  dec <- c(
    if(length(vars$b) > 0) paste0('  const double ', vars$b, ' = b["', vars$b, '"];'),
    if(length(vars$x) > 0) paste0('  const NumericVector ', vars$x, ' = db["', vars$x, '"];'),
    if(length(vars$d) > 0) paste0('  const NumericMatrix ', vars$d, ' = drw["', vars$d, '"];'),
    if(length(rNames) > 0) paste0('  double ', rNames, ';'),
    if(length(pNames) > 0) paste0('  double ', pNames, ';'),
    if(length(vNames) > 0) paste0('  NumericMatrix V_', vNames, '(', nObs, ',', nDrw, ');')
  )
  dec <- paste0(dec, collapse="\n")
  
  # Check if any definition does not change with each draw and perform calculation outside of loop
  # Not implemented yet
  
  # Appends an index suffix to every whole-identifier occurrence of each name.
  # The lookarounds stop a name from matching inside a longer identifier
  # (e.g. "tt" inside "tt_car" or "b_tt"), and \Q...\E makes the name match
  # literally instead of being interpreted as a regular expression.
  addIndex <- function(code, varNames, index){
    for(j in varNames){
      pattern <- paste0("(?<![A-Za-z0-9_])\\Q", j, "\\E(?![A-Za-z0-9_])")
      code <- gsub(pattern, paste0(j, index), code, perl=TRUE)
    }
    code
  }
  
  # Statements evaluated for every observation n and draw d
  # Avoid calculating V for unavailable alternatives, if availability is provided (not implemented yet)
  body <- c(
    if(useR) paste0("     ", vars$r[,1], " = ", vars$r[,2], ";"),
    if(useP) paste0("     ", vars$p[,1], " = ", vars$p[,2], ";"),
    paste0("     V_", vars$v[,1], "(n,d) = ", vars$v[,2], ";")
  )
  body <- addIndex(body, vars$d, "(n,d)") # draws are matrices: observation by draw
  body <- addIndex(body, vars$x, "(n)")   # database columns are vectors: one value per observation
  
  # Build loop over observations and draws
  loop <- c(paste0("  for(int n=0; n<", nObs, "; ++n){"),
            paste0("    for(int d=0; d<", nDrw, "; ++d){"),
            body,
            "    }",
            "  }")
  loop <- paste0(loop, collapse="\n")
  
  # Put utilities in a list
  # List V = List::create(Named("name1") = v1 , Named("name2") = v2);
  # Collapsing a vectorised paste0 also handles a single utility correctly.
  vlst <- paste0("  List V = List::create(",
                 paste0('Named("', vNames, '") = V_', vNames, collapse=", "),
                 ");")
  
  # Put everything together
  header <- paste(c("#include <Rcpp.h>", 
                    "using namespace Rcpp;",
                    "// [[Rcpp::export]]",
                    "List utilPre(DataFrame db, NumericVector b, List drw, List L) {"),
                  collapse="\n")
  script <- paste0(c(header, dec, loop, vlst, "}"), collapse="\n\n")
  
  # Generate C++ code and return; a compilation error yields NULL instead of an error
  cppF <- tryCatch(Rcpp::cppFunction(script), error=function(e) NULL)
  return(cppF)
}
byu-transpolab/apollo-byu documentation built on Dec. 19, 2021, 12:49 p.m.