inst/AD_test_utils.R

# source(system.file(file.path('tests', 'testthat', 'test_utils.R'), package = 'nimble'))

# An environment for storing some defaults for AD testing.
# Having this allows one to change these values globally.
ADtestEnv <- new.env()
resetTols <- function() {
  ADtestEnv$RRrelTol <<- c(1e-8, 1e-3, 1e-2)
  ADtestEnv$RCrelTol <<- c(1e-15, 1e-8, 1e-3)
  ADtestEnv$CCrelTol <<- rep(sqrt(.Machine$double.eps), 3)
}
resetTols()

#####################
## knownFailure utils
#####################

## test_param_name: an AD test parameterization name as produced by
##                  make_op_param()
##
## returns: a length 2 character vector with the op and args
##
split_test_param_name <- function(test_param_name) {
  first_space <- regexpr(' ', test_param_name)
  op <- substr(test_param_name, 1, first_space - 1)
  args <- substr(test_param_name, first_space + 1, nchar(test_param_name))
  return(c(op, args))
}

is_failure <- function(test_param_name, failure_name, knownFailures = list()) {
  if (length(knownFailures) == 0) return(FALSE)
  op_and_args <- split_test_param_name(test_param_name)
  op <- op_and_args[1]
  args <- op_and_args[2]
  if (!is.null(knownFailures[[op]])) {
    return(
      isTRUE(knownFailures[[op]][[args]][[failure_name]]) ||
        isTRUE(knownFailures[[op]][['*']][[failure_name]])
    )
  }
  return(FALSE)
}

## test_param_name: an AD test parameterization name as produced by
##                  make_op_param()
## knownFailures:   a list of known test failures, e.g. in the format of
##                  AD_knownFailures
##
## returns: TRUE if the test param is known to lead to a compilation failure and
##          FALSE otherwise
##
is_compilation_failure <- function(test_param_name, knownFailures = list()) {
  return(is_failure(test_param_name, 'compilation', knownFailures))
}

is_segfault_failure <- function(test_param_name, knownFailures = list()) {
  return(is_failure(test_param_name, 'segfault', knownFailures))
}

is_method_failure <- function(test_param_name, method_name,
                              output_name = c('value', 'jacobian', 'hessian'),
                              knownFailures = list()) {
  if (length(knownFailures) == 0) return(FALSE)
  output_name <- match.arg(output_name)
  op_and_args <- split_test_param_name(test_param_name)
  op <- op_and_args[1]
  args <- op_and_args[2]
  if (!is.null(knownFailures[[op]])) {
    this_arg_fails <- all_args_fail <- FALSE
    if (!is.null(knownFailures[[op]][[args]]))
      this_arg_fails <- identical(
        knownFailures[[op]][[args]][[output_name]],
        method_name
      )
    if (!is.null(knownFailures[[op]][['*']]))
      all_args_fail <- identical(
        knownFailures[[op]][['*']][[output_name]],
        method_name
      )
    return(this_arg_fails || all_args_fail)
  }
  return(FALSE)
}

all_equal_ignore_zeros <- function(v1, v2, ...) {
  if(length(v1) != length(v2)) return(FALSE)
  zv1 <- as.logical(v1 == 0) ## as.logical will strip any names
  zv2 <- as.logical(v2 == 0)
  nzboth <- !(zv1 & zv2)
  if(sum(nzboth)==0) return(FALSE) ## This is used to validate expected discrepancies except when they are zero, so if all are zero, we default to that there were discrepancies, which allows the test to pass.
  all.equal(v1[nzboth], v2[nzboth], ...)
}

########################
## Main AD testing utils
########################

test_AD2_oneCall <- function(Robj, Cobj,
                             recordArgs,
                             testArgs,
                             order,
                             wrt = NULL,
                             check.equality = TRUE,
                             # In normal usage, the tols are set from test_AD2 and
                             # these defaults are not used
                             RRrelTol = c(1e-8, 1e-3, 1e-2, 1e-14),
                             RCrelTol = c(1e-15, 1e-8, 1e-3, 1e-14),
                             CCrelTol = c(rep(sqrt(.Machine$double.eps), 3), 1e-14),
                             Rmodel = NULL, Cmodel = NULL,
                             recordInits = NULL, testInits = NULL,
                             nodesToChange = NULL) {
  resRecord <- list()
  resTest <- list()

  useRmodel <- !is.null(Rmodel)
  useCmodel <- !is.null(Cmodel)

  RRabsThresh <- 0
  RCabsthresh <- 0
  CCabsThresh <- 0
  if(length(RRrelTol) > 3) RRabsThresh <- RRrelTol[4]
  if(length(RCrelTol) > 3) RCabsThresh <- RCrelTol[4]
  if(length(CCrelTol) > 3) CCabsThresh <- CCrelTol[4]

  if(is.null(names(RRrelTol))) names(RRrelTol)[1:3] <- c("0", "1", "2")
  if(is.null(names(RCrelTol))) names(RCrelTol)[1:3] <- c("0", "1", "2")
  if(is.null(names(CCrelTol))) names(CCrelTol)[1:3] <- c("0", "1", "2")

  setModelInits <- function(m, inits) {
    for(v in names(inits)) {
      m[[v]] <- inits[[v]]
    }
    m$calculate()
  }

  changeModelInits <- function(m, inits, nodesToChange) {
    for(node in nodesToChange) { # This will not include constant nodes and can include like 'x[2:3]'
      if(all.vars(parse(text = node))[1] %in% names(inits))
        eval(parse(text = paste0("m$", node, " <- inits$", node)))
    }
    m$calculate()
  }

  do_one_call <- function(fun, argList) {
    eval( as.call( c(fun, argList) ) )
  }

  do_one_set <- function(method, order, metaLevel = 0, name,
                         doR, doC, tapingLevel = c(1, 0, 2), fixedOrder = FALSE) {
    # tapingLevel: 0 for the (run) function itself, 1 for a tape of run, 2 for a tape of a tape of run
    # metaLevel: 0 if order 0 from the tape is order 0 result.  1 if order 0 gives order 1 result due to double taping.  2 if order 0 gives order 2 result due to double taping.
    # fixedOrder: FALSE is return object is result of nimDerivs call.
    #             0, 1, or 2 if return object is the value, jacobian or hessian directly.
    if(tapingLevel == 0) order <- 0
    if(isFALSE(fixedOrder))
      metaOrder <- order[order >= metaLevel] - metaLevel # e.g. for metaLevel 1, count 1 as 0
    else
      metaOrder <- fixedOrder
    # Follow construction is because Robj[[method]] doesn't work until
    # Robj$run (if method == "run") has been done.  It's a flaw in reference classes
    Rfun <- eval( parse( text = paste0("Robj$", method), keep.source = FALSE)[[1]])
    Cfun <- eval( parse( text = paste0("Cobj$", method), keep.source = FALSE)[[1]])
    extraArgs <- list()
    if(tapingLevel >= 1) {
      extraArgs$wrt <- wrt_all
      # If there is a tape and no fixed order, the order arg is needed
      if(isFALSE(fixedOrder))
        extraArgs$order <- metaOrder
      extraArgs$reset <- TRUE
    }
    if(tapingLevel == 2) {
      extraArgs$innerWrt = wrt_all
    }
    argList <- c(recordArgs, extraArgs)
    if(doR) {
      if(useRmodel) setModelInits(Rmodel, recordInits)
      RansRecord <- do_one_call(Rfun, argList)
    }
    if(doC) {
      if(useCmodel) setModelInits(Cmodel, recordInits)
      CansRecord <- do_one_call(Cfun, argList)
    }
    if(tapingLevel >= 1)
      extraArgs$reset <- FALSE

    argList <- c(testArgs, extraArgs)
    if(doR) {
      if(useRmodel) changeModelInits(Rmodel, testInits, nodesToChange)
      RansTest <- do_one_call(Rfun, argList)
    }
    if(doC) {
      if(useCmodel) changeModelInits(Cmodel, testInits, nodesToChange)
      CansTest <- do_one_call(Cfun, argList)
    }

    pieces <- c("0"="value", "1"="jacobian", "2"="hessian")
    rName <- paste0("R", name)
    cName <- paste0("C", name)

    if(!isFALSE(fixedOrder)) {
      oChar <- as.character(fixedOrder) # Assume fixedOrder given directly (no need to adjust for metaOrder)
      if(doR) {
        resRecord[[oChar]][[rName]] <<- RansRecord
        resTest[[oChar]][[rName]] <<- RansTest
      }
      if(doC) {
        resRecord[[oChar]][[cName]] <<- CansRecord
        resTest[[oChar]][[cName]] <<- CansTest
      }
    } else {
      for(o in as.character(metaOrder)) {
        oChar <- as.character(as.numeric(o) + metaLevel)
        if((metaLevel == 1) & o == '1') {
          argsLength <- length(wrt_all)
          reorder_jac_jac <- function(x) {
            outLength <- length(x) / (argsLength*argsLength)
            aperm(array(as.numeric(x), c(outLength, argsLength, argsLength)), c(2, 3, 1))
          }
        }
        if(doR) {
          if((metaLevel == 1) & o == '1') {
            resRecord[[oChar]][[rName]] <<- reorder_jac_jac(RansRecord[[ pieces[[o]] ]])
            resTest[[oChar]][[rName]] <<- reorder_jac_jac(RansRecord[[ pieces[[o]] ]])
          } else {
            resRecord[[oChar]][[rName]] <<- RansRecord[[ pieces[[o]] ]]
            resTest[[oChar]][[rName]] <<- RansTest[[ pieces[[o]] ]]
          }
        }
        if(doC) {
          if((metaLevel == 1) & o == '1') {
            resRecord[[oChar]][[cName]] <<- reorder_jac_jac(CansRecord[[ pieces[[o]] ]])
            resTest[[oChar]][[cName]] <<- reorder_jac_jac(CansTest[[ pieces[[o]] ]])
          } else {
            resRecord[[oChar]][[cName]] <<- CansRecord[[ pieces[[o]] ]]
            resTest[[oChar]][[cName]] <<- CansTest[[ pieces[[o]] ]]
          }
        }
      }
    }
  }

  len_record <- sum(unlist(lapply(recordArgs, length)))
  len_test <- sum(unlist(lapply(testArgs, length)))
  if(len_record != len_test) {
    stop('lengths of recordArgs and testArgs must match.')
  }

  if(is.null(wrt))
    wrt_all <- 1:len_record
  else
    wrt_all <- wrt

  do_one_set("run", name = "run",
             doR = TRUE, doC = TRUE, tapingLevel = 0, fixedOrder = 0)
  do_one_set("derivsRun", order = order, metaLevel = 0, name = "derivsRun",
             doR = TRUE, doC = TRUE, tapingLevel = 1, fixedOrder = FALSE)
  do_one_set("value", name = "value",
             doR = FALSE, doC = TRUE, tapingLevel = 1, fixedOrder = 0)
  do_one_set("jac", name = "jacR1",
             doR = FALSE, doC = TRUE, tapingLevel = 1, fixedOrder = 1)
  do_one_set("jac2", name = "jacF1",
             doR = FALSE, doC = TRUE, tapingLevel = 1, fixedOrder = 1)
  do_one_set("hess", name = "hess",
             doR = FALSE, doC = TRUE, tapingLevel = 1, fixedOrder = 2)
  do_one_set("derivsValue", order = 0:2, name = "derivsValue",
             doR = FALSE, doC = TRUE, tapingLevel = 2, fixedOrder = FALSE)
  do_one_set("derivsJac", order = 1:2, name = "derivsJacR1", metaLevel = 1,
             doR = FALSE, doC = TRUE, tapingLevel = 2, fixedOrder = FALSE)
  do_one_set("derivsJac2", order = 1:2, name = "derivsJacF1", metaLevel = 1,
             doR = FALSE, doC = TRUE, tapingLevel = 2, fixedOrder = FALSE)
  do_one_set("derivsHess", order = 2, name = "derivsHess", metaLevel = 2,
             doR = FALSE, doC = TRUE, tapingLevel = 2, fixedOrder = FALSE)

  all_equal_list <- function(first, others, tol,
                             abs_threshold, info = "") {
    pass <- TRUE
    for(i in seq_along(others)) {
      pass <- pass && nim_all_equal(as.numeric(first), as.numeric(others[[i]]),
                                    tol, abs_threshold = abs_threshold,
                                    verbose = TRUE, info = info)
    }
    pass
  }

  if(check.equality) {
    pass <- TRUE
    for(res in list(resRecord, resTest))
      for(o in as.character(order)) {
        ansSet <- res[[o]]
        splitAnsSet <- split(ansSet,
                             c("C", "R")[as.integer(grepl("^R", names(ansSet)))+1])
        RansSet <- splitAnsSet[["R"]]
        CansSet <- splitAnsSet[["C"]]
        if(pass)
        if(length(RansSet) > 1) {
          pass <- pass && all_equal_list(RansSet[[1]], RansSet[-1], tol = RRrelTol[[o]],
                                         abs_threshold = RRabsThresh,
                                         info = paste0("(RR order ", o,")"))
          if(!pass) {
            cat(paste('Some R-to-R derivatives do not match for order',o))
            # browser()
          }
        }
        if(pass)
        if(length(RansSet) > 0 && length(CansSet) > 0) {
          pass <- pass && all_equal_list(RansSet[[1]], CansSet, tol = RCrelTol[[o]],
                                         abs_threshold = RCabsThresh,
                                         info = paste0("(RC order ", o,")"))
          if(!pass) {
            cat(paste('Some C-to-R derivatives to not match for order',o))
            browser()
          }
        }
        if(pass)
        if(length(CansSet) > 1) {
          pass <- pass && all_equal_list(CansSet[[1]], CansSet[-1], CCrelTol[[o]],
                                         abs_threshold = CCabsThresh,
                                         info = paste0("(CC order ", o, ")"))
          if(!pass) {
            cat(paste('Some C-to-C derivatives to not match for order',o))
            browser()
          }
        }
      }
    expect_true(pass)
  }
  list(resRecord = resRecord, resTest = resTest)
}

test_AD2 <- function(param, dir = file.path(tempdir(), "nimble_generatedCode"),
                     control = list(), verbose = nimbleOptions('verbose'),
                     catch_failures = FALSE, seed = NULL,
                     nimbleProject_name = '', return_compiled_nf = FALSE,
                     knownFailures = list()) {
  if (!is.null(param$debug) && param$debug) browser()
  if (verbose) cat(paste0("### Testing ", param$name, "\n"))

  ## by default, reset the seed for every test
  if(is.numeric(param[['seed']])) set.seed(param[['seed']])
  else if(is.numeric(seed)) set.seed(seed)

  nf <- nimbleFunction(
    setup = function() {},
    run = param$run,
    methods = param$methods,
    buildDerivs = param$buildDerivs
  )
  Robj <- nf()
  temporarilyAssignInGlobalEnv(Robj)

  if(!is.null(param$inputs)) {
    inputsRecord <- param$inputs[[1]]
    inputsTest <- param$inputs[[2]]
  } else {
    ## input_gen_funs might be generalized.
    opParam <- param$opParam
    if (is.null(param$input_gen_funs) || is.null(names(param$input_gen_funs)))
      if (length(param$input_gen_funs) <= 1) {
        inputsRecord <- lapply(opParam$args, arg_type_2_input,
                               input_gen_fun = param$input_gen_funs, size = param$size)
        inputsTest <- lapply(opParam$args, arg_type_2_input,
                             input_gen_fun = param$input_gen_funs, size = param$size)
      }
    else
      stop(
        'input_gen_funs of length greater than 1 must have names',
        call. = FALSE
      )
    else {
      inputsRecord <- sapply(
        names(opParam$args),
        function(name)
          arg_type_2_input(opParam$args[[name]],  input_gen_fun = param$input_gen_funs[[name]],
                           size = param$size[[name]]),
        simplify = FALSE
      )
      inputsTest <- sapply(
        names(opParam$args),
        function(name)
          arg_type_2_input(opParam$args[[name]],  input_gen_fun = param$input_gen_funs[[name]],
                           size = param$size[[name]]),
        simplify = FALSE
      )
    }
  }
  ##
  ## generate inputs that depend on the other inputs
  ##
  is_fun <- sapply(inputsRecord, is.function)
  inputsRecord[is_fun] <- lapply(
    inputsRecord[is_fun], function(fun) {
      eval(as.call(c(fun, inputsRecord[names(formals(fun))])))
    }
  )
  is_fun <- sapply(inputsTest, is.function)
  inputsTest[is_fun] <- lapply(
    inputsTest[is_fun], function(fun) {
      eval(as.call(c(fun, inputsTest[names(formals(fun))])))
    }
  )

  # Currently only whole arguments supported.
  # Hence all this does is create indices that skip over non-wrt args
  lens <- lapply(inputsRecord, length)
  if(!is.null(param$wrt_args)) {
    nextInd <- 1
    lenInds <- lapply(lens,
                      function(x) {
                        ans <- nextInd:(nextInd + x - 1); nextInd <<- nextInd + x; ans})
    wrt <- unlist(lenInds[param$wrt_args])
  } else {
    wrt <- 1:(sum(unlist(lens)))
  }

  if (is_segfault_failure(param$name, knownFailures)) {
    if (verbose) cat("## Skipping the rest of test before compilation",
                     "due to known segmentation fault\n")
    return(invisible(NULL))
  }

  if (!is.null(param$dir)) dir <- param$dir
  compilation_fails <- is_compilation_failure(param$name, knownFailures)
  Cobj <- param$Cobj ## user provided compiled nimbleFunction?
  if (is.null(Cobj)) {
    if (verbose) cat("## Compiling nimbleFunction \n")
    Cobj <- wrap_if_true(compilation_fails, expect_error, {
      compileNimble(
        Robj, dirName = dir, projectName = nimbleProject_name,
        control = control
      )
    }, wrap_in_try = isTRUE(catch_failures))
  }
  if (isTRUE(catch_failures) && inherits(Cobj, 'try-error')) {
    warning(
      paste0(
        'The test of ', opParam$name,
        ' failed to compile.\n'
      ),
      call. = FALSE,
      immediate. = TRUE
    )
    ## stop the test here because it didn't compile
    return(invisible(NULL))
  } else if (compilation_fails) {
    if (verbose) cat("## Compilation failed, as expected \n")
    return(invisible(NULL))
  } else {
    RRrelTol <- param$RRrelTol
    RCrelTol <- param$RCrelTol
    CCrelTol <- param$CCrelTol
    if(is.null(RRrelTol)) RRrelTol <- ADtestEnv$RRrelTol
    if(is.null(RCrelTol)) RCrelTol <- ADtestEnv$RCrelTol
    if(is.null(CCrelTol)) CCrelTol <- ADtestEnv$CCrelTol

    res <- try(test_AD2_oneCall(Robj = Robj, Cobj = Cobj,
                                recordArgs = inputsRecord, testArgs = inputsTest,
                                order = 0:2,
                                wrt = wrt,
                                RRrelTol = RRrelTol,
                                RCrelTol = RCrelTol,
                                CCrelTol = CCrelTol))
    if (inherits(res, 'try-error')) {
      msg <- paste(
        'Something failed in test', opParam$name, '.\n'
      )
      if (isTRUE(catch_failures)) ## continue to compilation
        warning(msg, call. = FALSE, immediate. = TRUE)
      else
        stop(msg, call. = FALSE) ## throw an error here
    }

  }
  if(!compilation_fails) {
    nimble:::clearCompiled(Cobj)
  }
  list(res = res)
}
## Take a test parameterization created by make_AD_test() or
## make_distribution_fun_AD_test(), generate a random input, and test for
## matching nimDerivs outputs from uncompiled and compiled versions of a
## nimbleFunction.
##
## param:              an test parameterization generated by make_AD_test() or
##                     make_distribution_fun_AD_test()
## dir:                passed to compileNimble() as the dirName argument
## control:            passed to compileNimble() as the control argument
## verbose:            if TRUE, print messages while testing
## catch_failures:     if TRUE, don't stop testing when a testthat expect_*
##                     fails
## seed:               seed to use in set.seed() before generating random inputs
## nimbleProject_name: passed to compileNimble() as the projectName argument
## return_compiled_nf: if TRUE, don't call clearCompiled() and include the
##                     compiled nimbleFunction instance in the output
##
## returns: a list with the randomly generated input and possibly the compiled
##          nimbleFunction instance, or NULL if the test has a known compilation
##          failure
test_AD <- function(param, dir = file.path(tempdir(), "nimble_generatedCode"),
                    control = list(), verbose = nimbleOptions('verbose'),
                    catch_failures = FALSE, seed = 0,
                    nimbleProject_name = '', return_compiled_nf = FALSE,
                    knownFailures = list()) {
  if (!is.null(param$debug) && param$debug) browser()
  if (verbose) cat(paste0("### Testing ", param$name, "\n"))

  ## by default, reset the seed for every test
  if (is.numeric(seed)) set.seed(seed)

  ## TODO: use nClass instead?
  nf <- nimbleFunction(
    setup = function() {},
    run = param$run,
    methods = param$methods,
    buildDerivs = param$buildDerivs
  )
  nfInst <- nf()
  temporarilyAssignInGlobalEnv(nfInst)

  ##
  ## generate inputs for the nimbleFunction methods
  ##
  opParam <- param$opParam
  if (is.null(param$input_gen_funs) || is.null(names(param$input_gen_funs)))
    if (length(param$input_gen_funs) <= 1)
      input <- lapply(opParam$args, arg_type_2_input, param$input_gen_funs)
  else
    stop(
      'input_gen_funs of length greater than 1 must have names',
      call. = FALSE
    )
  else {
    input <- sapply(
      names(opParam$args),
      function(name)
        arg_type_2_input(opParam$args[[name]], param$input_gen_funs[[name]]),
      simplify = FALSE
    )
  }
  ##
  ## generate inputs that depend on the other inputs
  ##
  is_fun <- sapply(input, is.function)
  input[is_fun] <- lapply(
    input[is_fun], function(fun) {
      eval(as.call(c(fun, input[names(formals(fun))])))
    }
  )

  if (is_segfault_failure(param$name, knownFailures)) {
    if (verbose) cat("## Skipping the rest of test before compilation",
                     "due to known segmentation fault\n")
    return(invisible(NULL))
  }

  ##
  ## compile the nimbleFunction
  ##
  if (!is.null(param$dir)) dir <- param$dir
  compilation_fails <- is_compilation_failure(param$name, knownFailures)
  CnfInst <- param$CnfInst ## user provided compiled nimbleFunction?
  if (is.null(CnfInst)) {
    if (verbose) cat("## Compiling nimbleFunction \n")
    CnfInst <- wrap_if_true(compilation_fails, expect_error, {
      compileNimble(
        nfInst, dirName = dir, projectName = nimbleProject_name,
        control = control
      )
    }, wrap_in_try = isTRUE(catch_failures))
  }
  if (isTRUE(catch_failures) && inherits(CnfInst, 'try-error')) {
    warning(
      paste0(
        'The test of ', opParam$name,
        ' failed to compile.\n', CnfInst[1]
      ),
      call. = FALSE,
      immediate. = TRUE
    )
    ## stop the test here because it didn't compile
    return(invisible(NULL))
  } else if (compilation_fails) {
    if (verbose) cat("## Compilation failed, as expected \n")
  } else {
    ##
    ## call R versions of nimbleFunction methods with generated input
    ##
    if (verbose) cat("## Calling R versions of nimbleFunction methods\n")
    Rderivs <- try(
      sapply(names(param$methods), function(method) {
        do.call(
          ## can't access nfInst[[method]] until $ has been used :(
          eval(substitute(nfInst$method, list(method = as.name(method)))), input
        )
      }, USE.NAMES = TRUE), silent = TRUE
    )
    if (inherits(Rderivs, 'try-error')) {
      msg <- paste(
        'Calling R version of test', opParam$name,
        'resulted in an error:\n', Rderivs[1]
      )
      if (isTRUE(catch_failures)) ## continue to compilation
        warning(msg, call. = FALSE, immediate. = TRUE)
      else
        stop(msg, call. = FALSE) ## throw an error here
    }

    ##
    ## call compiled nimbleFunction methods with generated input
    ##
    Cderivs <- sapply(names(param$methods), function(method) {
      do.call(
        ## same issue as with Rderivs
        eval(substitute(CnfInst$method, list(method = as.name(method)))),
        input
      )
    }, USE.NAMES = TRUE)
    if ('log' %in% names(opParam$args)) {
      input2 <- input
      input2$log <- as.numeric(!input$log)
      Rderivs2 <- try(
        sapply(names(param$methods), function(method) {
          do.call(nfInst[[method]], input2)
        }, USE.NAMES = TRUE), silent = TRUE
      )
      Cderivs2 <- sapply(names(param$methods), function(method) {
        do.call(CnfInst[[method]], input2)
      }, USE.NAMES = TRUE)
    }

    ##
    ## loop over test methods (each with a different wrt arg)
    ##
    ## set expect_equal tolerances
    tol1 <- if (is.null(param$tol1)) 1e-8 else param$tol1
    tol2 <- if (is.null(param$tol2)) 1e-6 else param$tol2
    tol3 <- if (is.null(param$tol3)) 1e-4 else param$tol3
    for (method_name in names(param$methods)) {
      info <- paste0(
        param$name, ', wrt = ',
        paste0(param$wrts[[method_name]], collapse = ', ')
      )
      if (verbose) {
        cat(paste0(
          "## Testing ", method_name, ': ',
          paste0(param$wrts[[method_name]], collapse = ', '),
          '\n'
        ))
      }
      ##
      ## test values
      ##
      value_test_fails <- is_method_failure(
        param$name, method_name, 'value', knownFailures
      )
      value_test <- wrap_if_true(value_test_fails, expect_failure, {
        if (verbose) cat("## Checking values\n")
        expect_equal(
          Cderivs[[method_name]]$value,
          Rderivs[[method_name]]$value,
          tolerance = tol1, info = info
        )
        if ('log' %in% names(opParam$args)) {
          if (verbose) cat("## Checking log behavior for values\n")
          expect_equal(
            Cderivs2[[method_name]]$value,
            Rderivs2[[method_name]]$value,
            tolerance = tol1, info = info
          )
          expect_false(isTRUE(all.equal(
            Rderivs[[method_name]]$value,
            Rderivs2[[method_name]]$value,
            tolerance = tol1)), info = info
          )
          expect_false(isTRUE(all.equal(
            Cderivs[[method_name]]$value,
            Cderivs2[[method_name]]$value,
            tolerance = tol1)), info = info
          )
        }
      }, wrap_in_try = isTRUE(catch_failures))
      if (isTRUE(catch_failures) && inherits(value_test, 'try-error')) {
        warning(
          paste0(
            'There was something wrong with the values of ',
            opParam$name, ' with wrt = c(',
            paste0(param$wrts[[method_name]], collapse = ', '), ').\n',
            value_test[1]
          ),
          call. = FALSE,
          immediate. = TRUE
        )
      } else if (value_test_fails) {
        if (verbose) {
          cat(paste0(
            "## As expected, test of values failed for ", method_name, ' with wrt: ',
            paste0(param$wrts[[method_name]], collapse = ', '),
            '\n'
          ))
        }
        ## stop testing after an expected failure
        break
      }
      ##
      ## test jacobians
      ##
      jacobian_test_fails <- is_method_failure(
        param$name, method_name, 'jacobian', knownFailures
      )
      jacobian_test <- wrap_if_true(jacobian_test_fails, expect_failure, {
        if (verbose) cat("## Checking jacobians\n")
        expect_equal(
          Cderivs[[method_name]]$jacobian,
          Rderivs[[method_name]]$jacobian,
          tolerance = tol2, info = info
        )
        if ('log' %in% names(opParam$args)) {
          if (verbose) cat("## Checking log behavior for jacobians\n")
          expect_equal(
            Cderivs2[[method_name]]$jacobian,
            Rderivs2[[method_name]]$jacobian,
            tolerance = tol2, info = info
          )
          expect_false(isTRUE(all_equal_ignore_zeros(
            Rderivs[[method_name]]$jacobian,
            Rderivs2[[method_name]]$jacobian,
            tolerance = tol2)), info = info
          )
          expect_false(isTRUE(all_equal_ignore_zeros(
            Cderivs[[method_name]]$jacobian,
            Cderivs2[[method_name]]$jacobian,
            tolerance = tol2)), info = info
          )
        }
      }, wrap_in_try = isTRUE(catch_failures))
      if (isTRUE(catch_failures) && inherits(jacobian_test, 'try-error')) {
        warning(
          paste0(
            'There was something wrong with the jacobian of ',
            opParam$name, ' with wrt = c(',
            paste0(param$wrts[[method_name]], collapse = ', '), ').\n',
            jacobian_test[1]
          ),
          call. = FALSE,
          immediate. = TRUE
        )
      } else if (jacobian_test_fails) {
        if (verbose) {
          cat(paste0(
            "## As expected, test of jacobian failed for ", method_name, ' with wrt: ',
            paste0(param$wrts[[method_name]], collapse = ', '),
            '\n'
          ))
        }
        ## stop testing after an expected failure
        break
      }
      ##
      ## test hessians
      ##
      hessian_test_fails <- is_method_failure(
        param$name, method_name, 'hessian', knownFailures
      )
      hessian_test <- wrap_if_true(hessian_test_fails, expect_failure, {
        if (verbose) cat("## Checking hessians\n")
        expect_equal(
          Cderivs[[method_name]]$hessian,
          Rderivs[[method_name]]$hessian,
          tolerance = tol3, info = info
        )
        if ('log' %in% names(opParam$args)) {
          if (verbose) cat("## Checking log behavior for hessians\n")
          expect_equal(
            Cderivs2[[method_name]]$hessian,
            Rderivs2[[method_name]]$hessian,
            tolerance = tol3, info = info
          )
          expect_false(isTRUE(all_equal_ignore_zeros(
            Rderivs[[method_name]]$hessian,
            Rderivs2[[method_name]]$hessian,
            tolerance = tol3)), info = info
          )
          expect_false(isTRUE(all_equal_ignore_zeros(
            Cderivs[[method_name]]$hessian,
            Cderivs2[[method_name]]$hessian,
            tolerance = tol3)), info = info
          )
        }
      }, wrap_in_try = isTRUE(catch_failures))
      if (isTRUE(catch_failures) && inherits(hessian_test, 'try-error')) {
        warning(
          paste0(
            'There was something wrong with the hessian of ',
            opParam$name, ' with wrt = c(',
            paste0(param$wrts[[method_name]], collapse = ', '), ').\n',
            hessian_test[1]
          ),
          call. = FALSE,
          immediate. = TRUE
        )
      } else if (hessian_test_fails) {
        if (verbose) {
          cat(paste0(
            "## As expected, test of hessian failed for ", method_name, ' with wrt: ',
            paste0(param$wrts[[method_name]], collapse = ', '),
            '\n'
          ))
        }
        ## stop testing after an expected failure
        break
      }
    }
  }
  if (verbose) cat("### Test successful \n\n")
  if (return_compiled_nf)
    invisible(list(CnfInst = CnfInst, input = input))
  else if(!compilation_fails) {
    nimble:::clearCompiled(CnfInst)
    invisible(list(input = input))
  }
  invisible(NULL)
}

test_AD_batch <- function(batch, dir = file.path(tempdir(), "nimble_generatedCode"),
                          control = list(), verbose = nimbleOptions('verbose'),
                          catch_failures = FALSE, seed = 0,
                          nimbleProject_name = '', knownFailures = list(),
                          testFun = test_AD) {
  ## could try to do something more clever here, like putting the entire batch
  ## into one giant nimbleFunction generator so we only have to compile once
  ## (perhaps conditional on having no knownFailures in the entire batch)
  lapply(
    batch, testFun,
    dir, control, verbose, catch_failures, seed,
    nimbleProject_name, FALSE, knownFailures
  )
  invisible(NULL)
}

#########################################
## AD test parameterization builder utils
#########################################

## Takes a named list of `argTypes` and returns a list of character
## vectors, each of which is valid as the `wrt` argument of `nimDerivs()`.
## Each argument on its own and all combinations of the arguments will
## always be included, and then make_wrt will try to create up to `n_random`
## additional character vectors with random combinations of the arguments
## and indexing of those arguments when possible (i.e. for non-scalar args).
## n_arg_reps determines how many times an argument can be used in a given
## wrt character vector. By default, any argument will appear only 1 time.
make_wrt <- function(argTypes, n_random = 10, n_arg_reps = 1) {

  ## always include each arg on its own, and all combinations of the args
  wrts <- as.list(names(argTypes))
  if (length(argTypes) > 1)
    for (m in 2:length(argTypes)) {
      this_combn <- combn(names(argTypes), m)
      wrts <- c(
        wrts,
        unlist(apply(this_combn, 2, list), recursive = FALSE)
      )
    }

  argSymbols <- lapply(
    argTypes, function(argType)
      add_missing_size(nimble:::argType2symbol(argType))
  )

  while (n_random > 0) {
    n_random  <- n_random - 1
    n <- sample(1:length(argTypes), 1) # how many of the args to use?
    ## grab a random subset of the args of length n
    args <- sample(argSymbols, n)
    ## may repeat an arg up to n_arg_reps times
    reps <- sample(1:n_arg_reps, length(args), replace = TRUE)
    this_wrt <- c()
    for (i in 1:length(args)) {
      while (reps[i] > 0) {
        reps[i] <- reps[i] - 1
        ## coin flip determines whether to index vectors/matrices
        use_indexing <- sample(c(TRUE, FALSE), 1)
        if (use_indexing && args[[i]]$nDim > 0) {
          rand_row <- sample(1:args[[i]]$size[1], size = 1)
          ## another coin flip determines whether to use : in indexing or not
          use_colon <- sample(c(TRUE, FALSE), 1)
          if (use_colon && rand_row < args[[i]]$size[1]) {
            end_row <- rand_row +
              sample(1:(args[[i]]$size[1] - rand_row), size = 1)
            rand_row <- paste0(rand_row, ':', end_row)
          }
          index <- rand_row
          if (args[[i]]$nDim == 2) {
            rand_col <- sample(1:args[[i]]$size[2], size = 1)
            ## one more coin flip to subscript second dimension
            use_colon_again <- sample(c(TRUE, FALSE), 1)
            if (use_colon_again && rand_col < args[[i]]$size[2]) {
              end_col <- rand_col +
                sample(1:(args[[i]]$size[2] - rand_col), size = 1)
              rand_col <- paste0(rand_col, ':', end_col)
            }
            index <- paste0(index, ',', rand_col)
          }
          this_wrt <- c(this_wrt, paste0(names(args)[i], '[', index, ']'))
        }
        ## if first coin flip was FALSE, just
        ## use the arg name without indexing
        else this_wrt <- c(this_wrt, names(args)[i])
      }
    }
    if (!is.null(this_wrt)) wrts <- c(wrts, list(unique(this_wrt)))
  }
  unique(wrts)
}

## A version 2 of make_AD_test.
#
# Construct nimbleFunction with following methods:
# value - simply calculate the function, the core of everything.
# derivsValue - get any derivatives of value.
# value2 - get the value from derivsValue. can be used for double (meta) taping.
# jac - get the jacobian from derivsValue via reverse order 1.  can double tape
# jac2 - get the jacobian from derivsValue via forward order 1 (same as en route to hessian). can double tape
# hess - get the hessian from derivsValue. can double tape.
# derivsJac - get any derivatives of jac
# derivsJac2 - get any derivatives of jac2
# derivsHess - get any derivatives of hess.
#
# For later testing (notation is "requested orders" --> "actual orders")
#   e.g. requesting order 0 from derivsJac gives order 1 of the actual function (value).
# run: gives 0
# derivsRun: 0, 1, 2 --> 0, 1, 2
# value: gives 0 (tests F0)
# jac: gives 1    (tests F0R1)
# jac2: gives 1   (tests F0F1)
# hess: gives 2   (tests F0F1R2)
# derivsValue: 0, 1, 2 --> 0, 1, 2 (tests meta-taped F0 for later derivs)
# derivsJac: 0, 1 --> 1, 2         (tests meta-taped F0R1 for later derivs)
# derivsJac2: 0, 1 --> 1, 2        (tests meta-taped F0F1 for later derivs)
# derivsHess: 0 --> 2              (tests meta-taped F0F0R2 for later derivs)
make_AD_test2 <- function(op, argTypes, wrt_args = NULL,
                          input_gen_funs = NULL, more_args = NULL, seed = 0,
                          outer_code = NULL, inner_codes = NULL,
                          size = NULL, inputs = NULL,
                          includeModelArgs = FALSE) {
  if(!is.list(op)) {
    opParam <- make_op_param(op, argTypes, more_args = more_args,
                             outer_code = outer_code,  inner_codes = inner_codes)
  } else {
    opParam <- op
  }
  run <- gen_runFunCore(opParam)

  if(seed == 0) seed <- round(runif(1, 1, 10000))

  modelArg <- NULL
  updateNodesArg <- NULL
  constantNodesArg <- NULL
  if(isTRUE(includeModelArgs)) {
    # If this is used, then some custom setup code
    # will need to be inserted to create model, updateNodes, and constantNodes
    modelArg <- as.name("model")
    updateNodesArg <- as.name("updateNodes")
    constantNodesArg <- as.name("constantNodes")
  }


  ## Make a set of methods. These need to be constructed
  ## so that they can have different argument names, numbers, and types.
  args_formals <- lapply(argTypes, function(argType) {
    parse(text = argType)[[1]]
  })
  if (is.null(names(args_formals)))
    names(args_formals) <- paste0('arg', 1:length(args_formals))

  runCall <- parse(text = paste0("run(", paste(names(args_formals), collapse=","), ")"),
                   keep.source = FALSE)[[1]]

  derivsRun <- eval(substitute(
    function(#arg1 = double(1) etc to be inserted
      wrt = double(1),
      order = integer(1),
      reset = logical(0, default = FALSE)) {
      ans <- nimDerivs(RUNCALL, wrt = wrt, order = order, reset = reset,
                       model = MODELARG, updateNodes = UPDATENODESARG,
                       constantNodes = CONSTANTNODESARG)
      return(ans)
      returnType(ADNimbleList())
    },
    list(RUNCALL = runCall, MODELARG = modelArg,
         UPDATENODESARG = updateNodesArg,
         CONSTANTNODESARG = constantNodesArg)
  ))
  attributes(derivsRun)$srcref <- NULL
  formals(derivsRun) <- c(args_formals, formals(derivsRun))

  value <- eval(substitute(
    function(#arg1 = double(1) etc to be inserted
      wrt = double(1),
      reset = logical(0, default = FALSE)) {
      ans <- nimDerivs(RUNCALL, wrt = wrt, order = 0, reset = reset,
                       model = MODELARG, updateNodes = UPDATENODESARG,
                       constantNodes = CONSTANTNODESARG)
      d1 <- dim(ans$value)[1]
      res <- numeric(length = d1)
      for(i in 1:d1)
        res[i] <- ans$value[i]
      return(res)
      returnType(double(1))
    },
    list(RUNCALL = runCall, MODELARG = modelArg,
         UPDATENODESARG = updateNodesArg,
         CONSTANTNODESARG = constantNodesArg)
  ))
  attributes(value)$srcref <- NULL
  formals(value) <- c(args_formals, formals(value))

  jac <- eval(substitute(
    function(#arg1 = double(1) etc to be inserted
      wrt = double(1),
      reset = logical(0, default = FALSE)) {
      ans <- nimDerivs(RUNCALL, wrt = wrt, order = 1, reset = reset,
                       model = MODELARG, updateNodes = UPDATENODESARG,
                       constantNodes = CONSTANTNODESARG)
      d1 <- dim(ans$jacobian)[1]
      d2  <- dim(ans$jacobian)[2]
      res <- numeric(length = d1*d2)
      for(i in 1:d1)
        for(j in 1:d2)
          res[i + (j-1)*d1] <- ans$jacobian[i, j]
      return(res)
      returnType(double(1))
    },
    list(RUNCALL = runCall, MODELARG = modelArg,
         UPDATENODESARG = updateNodesArg,
         CONSTANTNODESARG = constantNodesArg)
  ))
  attributes(jac)$srcref <- NULL
  formals(jac) <- c(args_formals, formals(jac))

  jac2 <- eval(substitute(
    function(#arg1 = double(1) etc to be inserted
      wrt = double(1),
      reset = logical(0, default = FALSE)) {
      # Because this gets order 2, the order 1 comes from forward 1
      ans <- nimDerivs(RUNCALL, wrt = wrt, order = 1:2, reset = reset,
                       model = MODELARG, updateNodes = UPDATENODESARG,
                       constantNodes = CONSTANTNODESARG)
      d1 <- dim(ans$jacobian)[1]
      d2  <- dim(ans$jacobian)[2]
      res <- numeric(length = d1*d2)
      for(i in 1:d1)
        for(j in 1:d2)
          res[i + (j-1)*d1] <- ans$jacobian[i, j]
      return(res)
      returnType(double(1))
    },
    list(RUNCALL = runCall, MODELARG = modelArg,
         UPDATENODESARG = updateNodesArg,
         CONSTANTNODESARG = constantNodesArg)
  ))
  attributes(jac2)$srcref <- NULL
  formals(jac2) <- c(args_formals, formals(jac2))

  hess <- eval(substitute(
    function(#arg1 = double(1) etc to be inserted
      wrt = double(1),
      reset = logical(0, default = FALSE)) {
      ans <- nimDerivs(RUNCALL, wrt = wrt, order = 2, reset = reset,
                       model = MODELARG, updateNodes = UPDATENODESARG,
                       constantNodes = CONSTANTNODESARG)
      d1 <- dim(ans$hessian)[1]
      d2 <- dim(ans$hessian)[2]
      d3 <- dim(ans$hessian)[3]
      res <- numeric(length = d1*d2*d3)
      # There was a bug in pulling out results in a single line.
      # It is unrelated to AD so considered separate.
      for(i in 1:d1)
        for(j in 1:d2)
          for(k in 1:d3)
            res[i + (j-1)*d1 + (k-1)*d1*d2] <- ans$hessian[i, j, k]
      return(res)
      returnType(double(1))
    },
    list(RUNCALL = runCall, MODELARG = modelArg,
         UPDATENODESARG = updateNodesArg,
         CONSTANTNODESARG = constantNodesArg)
  ))
  attributes(hess)$srcref <- NULL
  formals(hess) <- c(args_formals, formals(hess))

  valueCall <- parse(text = paste0("value(",
                                   paste(c(names(args_formals),
                                           "wrt = innerWrt",
                                           "reset=reset"), collapse=","),
                                   ")"),
                     keep.source = FALSE)[[1]]

  derivsValue <- eval(substitute(
    function(#arg1 = double(1) etc to be inserted
      innerWrt = double(1),
      wrt = double(1),
      order = integer(1),
      reset = logical(0, default = FALSE) ) {
      ans <- nimDerivs(VALUECALL, wrt = wrt, order = order, reset = reset,
                       model = MODELARG, updateNodes = UPDATENODESARG,
                       constantNodes = CONSTANTNODESARG)
      return(ans)
      returnType(ADNimbleList())
    },
    list(VALUECALL = valueCall, MODELARG = modelArg,
         UPDATENODESARG = updateNodesArg,
         CONSTANTNODESARG = constantNodesArg)
  ))
  attributes(derivsValue)$srcref <- NULL
  formals(derivsValue) <- c(args_formals, formals(derivsValue))


  jacCall <- parse(text = paste0("jac(",
                                 paste(c(names(args_formals),
                                         "wrt = innerWrt",
                                         "reset=reset"), collapse=","),
                                 ")"),
                   keep.source = FALSE)[[1]]

  derivsJac <- eval(substitute(
    function(#arg1 = double(1) etc to be inserted
      innerWrt = double(1),
      wrt = double(1),
      order = integer(1),
      reset = logical(0, default = FALSE) ) {
      ans <- nimDerivs(JACCALL, wrt = wrt, order = order, reset = reset,
                       model = MODELARG, updateNodes = UPDATENODESARG,
                       constantNodes = CONSTANTNODESARG)
      return(ans)
      returnType(ADNimbleList())
    },
    list(JACCALL = jacCall, MODELARG = modelArg,
         UPDATENODESARG = updateNodesArg,
         CONSTANTNODESARG = constantNodesArg)
  ))
  attributes(derivsJac)$srcref <- NULL
  formals(derivsJac) <- c(args_formals, formals(derivsJac))

  jac2Call <- parse(text = paste0("jac2(",
                                  paste(c(names(args_formals),
                                          "wrt = innerWrt",
                                          "reset=reset"), collapse=","),
                                  ")"),
                    keep.source = FALSE)[[1]]

  derivsJac2 <- eval(substitute(
    function(#arg1 = double(1) etc to be inserted
      innerWrt = double(1),
      wrt = double(1),
      order = integer(1),
      reset = logical(0, default = FALSE) ) {
      ans <- nimDerivs(JAC2CALL, wrt = wrt, order = order, reset = reset,
                       model = MODELARG, updateNodes = UPDATENODESARG,
                       constantNodes = CONSTANTNODESARG)
      return(ans)
      returnType(ADNimbleList())
    },
    list(JAC2CALL = jac2Call, MODELARG = modelArg,
         UPDATENODESARG = updateNodesArg,
         CONSTANTNODESARG = constantNodesArg)
  ))
  attributes(derivsJac2)$srcref <- NULL
  formals(derivsJac2) <- c(args_formals, formals(derivsJac2))


  hessCall <- parse(text = paste0("hess(",
                                  paste(c(names(args_formals),
                                          "wrt = innerWrt",
                                          "reset=reset"), collapse=","),
                                  ")"),
                    keep.source = FALSE)[[1]]
  derivsHess  <-  eval(substitute(
    function(#arg1 = double(1) etc to be inserted
      innerWrt = double(1),
      wrt = double(1),
      order = integer(1),
      reset = logical(0, default = FALSE) ) {
      ans <- nimDerivs(HESSCALL, wrt = wrt, order = order, reset = reset,
                       model = MODELARG, updateNodes = UPDATENODESARG,
                       constantNodes = CONSTANTNODESARG)
      return(ans)
      returnType(ADNimbleList())
    },
    list(HESSCALL = hessCall, MODELARG = modelArg,
         UPDATENODESARG = updateNodesArg,
         CONSTANTNODESARG = constantNodesArg)
  ))
  attributes(derivsHess)$srcref <- NULL
  formals(derivsHess) <- c(args_formals, formals(derivsHess))

  methods <- list(derivsRun = derivsRun,
                  value = value,
                  jac = jac,
                  jac2 = jac2,
                  hess = hess,
                  derivsValue = derivsValue,
                  derivsJac = derivsJac,
                  derivsJac2 = derivsJac2,
                  derivsHess = derivsHess
  )

  list(
    name = opParam$name,
    opParam = opParam,
    run = run,
    methods = methods,
    buildDerivs = list(run = list()
                       ,
                       value = list(ignore = c('wrt', 'i')),
                       jac = list(ignore = c('wrt', 'i', 'j')),
                       jac2 = list(ignore = c('wrt', 'i', 'j')),
                       hess = list(ignore = c('wrt', 'i', 'j', 'k'))
    ),
    wrt_args = wrt_args,
    input_gen_funs = input_gen_funs,
    size = size,
    seed = seed,
    inputs = inputs
  )
}


model_calculate_test_pieces <- make_AD_test2(
  op = list(
    expr = quote( {
      values(model, derivNodes) <<- arg1
      out <- model$calculate(calcNodes)
      return(out)
    }),
    args = list(arg1 = quote(double(1))),
    outputType = quote(double())
  ),
  argTypes = list(arg1 = "double(1)"),
  includeModelArgs = TRUE)

model_calculate_test <- nimbleFunction(
  setup = function(model, nodesList) {
    derivNodes <- nodesList$derivNodes
    updateNodes <- nodesList$updateNodes
    constantNodes <- nodesList$constantNodes
    calcNodes <- nodesList$calcNodes
    nNodes <- length(derivNodes)
  },
  run = model_calculate_test_pieces$run,
  methods = model_calculate_test_pieces$methods,
  buildDerivs = model_calculate_test_pieces$buildDerivs
)

setup_update_and_constant_nodes_for_tests <- function(model,
                                                      derivNodes,
                                                      forceConstantNodes = character(),
                                                      forceUpdateNodes = character()) {
  ## "update" means "CppAD dynamic"
  ##  derivNodes <- model$expandNodeNames(derivNodes) # do not do this because do not want vector node names
  nNodes <- length(derivNodes)
  calcNodes <- model$getDependencies(derivNodes)
  ucNodes <- makeModelDerivsInfo(model, derivNodes, calcNodes, dataAsConstantNodes = TRUE)
  updateNodes <- ucNodes$updateNodes
  constantNodes <- ucNodes$constantNodes
  updateNodes <- setdiff(updateNodes, forceConstantNodes) # remove forceConstants from updates
  constantNodes <- setdiff(constantNodes, forceUpdateNodes) # remove forceUpdates from constants
  constantNodes <- union(constantNodes, forceConstantNodes) # add forceConstants to constants
  updateNodes <- union(updateNodes, forceUpdateNodes)     # add forceUpdates to updates
  list(derivNodes = derivNodes, updateNodes = updateNodes,
       constantNodes = constantNodes, calcNodes = calcNodes)
}

model_calculate_test_case <- function(Rmodel, Cmodel,
                                      deriv_nf, nodesList,
                                      v1, v2  ,
                                      order,
                                      varValues = list(),
                                      varValues2 = list(), ...) {
  # This sets up an instance of a checker fxn
  Rfxn <- deriv_nf( Rmodel, nodesList)
  Cfxn <- compileNimble(Rfxn, project = Rmodel)

  test_AD2_oneCall(Rfxn, Cfxn,
                   recordArgs = v1, testArgs = v2,
                   order = order,
                   Rmodel = Rmodel, Cmodel = Cmodel,
                   recordInits = varValues, testInits = varValues2,
                   nodesToChange = c(nodesList$updateNodes),
                   ...)
}

## Make a test parameterization to be used by test_AD. This method is primarily
## used by make_AD_test_batch() and make_distribution_fun_AD_test().
##
## op:             Character string, the operator that will be the focus of the
##                 test.
## argTypes:       Character vector of argType strings that, when parsed, can be
##                 passed to argType2symbol. If named, the names will as the formals
##                 of the nimbleFunction generator's run method and other methods.
##                 If not, formals are generated as arg1, arg2, etc.
## wrt_args:       Optional character vector of args to use in make_wrt(). If NULL,
##                 assumes that all the arguments should be used.
## input_gen_funs: A list of input generation functions which is simply passed to
##                 the output list. Should be NULL (use defaults found in
##                 arg_type_2_input()), length 1 (use same input gen mechanism for each
##                 argType, or a named list with names from among the argType names
##                 (possibly the sequentially generated names). This will be NULL
##                 when bulk generating the test params using make_AD_test_batch and
##                 added later via modify_on_match(). Used in the call to
##                 make_AD_test() in make_distribution_fun_AD_test().
## more_args:      A named list of additional fixed arguments to use in the
##                 generated operator call. E.g., if op = 'dnorm',
##                 argTypes = c('double(1, 5)', 'double(0)'), and
##                 more_args = list(log = 1), the call to make_op_param will include
##                 the expression dnorm(arg1, arg2, log = 1).
## seed:           A seed to use in set.seed().
##
## returns: A list with the following elements:
##          name:           character string, the operator and args
##          opParam:        result from make_op_param
##          run:            result from calling gen_runFunCore with opParam
##          methods:        a list of nimbleFunction method expressions that are
##                          the calls to nimDerivs() of the run method, each with
##                          a different wrt argument
##          buildDerivs:   list('run')
##          wrts:           a list of character vectors, each of which is the wrt
##                          argument for the corresponding method in methods
##          input_gen_funs: A list of random input generation functions to be
##                          used by arg_type_2_input().
make_AD_test <- function(op, argTypes, wrt_args = NULL,
                         input_gen_funs = NULL, more_args = NULL, seed = 0, ...) {
  ## set the seed for make_wrt
  if (is.numeric(seed)) set.seed(seed)
  opParam <- make_op_param(op, argTypes, more_args)

  run <- gen_runFunCore(opParam)
  method <- function() {
    {}
    returnType(ADNimbleList())
    return(outList)
  }
  formals_list <- lapply(argTypes, function(argType) {
    parse(text = argType)[[1]]
  })
  if (is.null(names(formals_list)))
    names(formals_list) <- paste0('arg', 1:length(formals_list))
  formals(method) <- formals_list

  body(method)[[2]] <-
    substitute(
      outList <- derivs(this_call),
      list(
        this_call = as.call(
          c(quote(run), lapply(names(formals_list), as.name))
        )
      )
    )

  methods <- list()

  if (is.null(wrt_args)) wrt_args_filter <- rep(TRUE, length(argTypes))
  else wrt_args_filter <- wrt_args
  wrts <- make_wrt(opParam$args[wrt_args_filter])
  for (i in seq_along(wrts)) {
    method_i <- paste0('method', i)
    methods[[method_i]] <- method
    body(methods[[method_i]])[[2]][[3]][['wrt']] <- wrts[[i]]
  }

  if (is.null(wrt_args) || all(names(opParam$args) %in% wrt_args)) {
    method_no_wrt <- paste0('method', length(methods) + 1)
    methods[[method_no_wrt]] <- method
    wrts <- c(wrts, 'no wrt')
  }

  names(wrts) <- names(methods)

  ## method_all_wrt <- paste0('method', length(wrt) + 2)
  ## wrts[[method_all_wrt]] <- paste0('wrt ', paste0(wrt, collapse = ', '))
  ## methods[[method_all_wrt]] <- methods[[method_no_wrt]] <-  method
  ## body(methods[[method_all_wrt]])[[2]][[3]][['wrt']] <- wrt

  list(
    name = opParam$name,
    opParam = opParam,
    run = run,
    methods = methods,
    buildDerivs = 'run',
    wrts = wrts,
    input_gen_funs = input_gen_funs
  )
}

## ops: character vector of operator names
## argTypes: list of character vectors of argTypes
##           e.g. for a binary operator:
##             list(
##               c('double(1, 4)', 'double(0)'),
##               c('double(1, 4)', 'double(1, 4)')
##             )
make_AD_test_batch <- function(ops, argTypes, seed = 0,
                               maker = make_AD_test,
                               outer_code = NULL, inner_codes = NULL) {
  opTests <- vector(mode = 'list', length = length(ops) * length(argTypes))
  for(i in seq_along(ops)) {
    for(j in seq_along(argTypes)) {
      iOut <- (i-1) * length(argTypes) + j
      opTests[[iOut]] <- maker(op = ops[[i]], argTypes = argTypes[[j]],
                               seed = seed, outer_code = outer_code, inner_codes = inner_codes)
    }
  }

  ## opTests <- unlist(
  ##   recursive = FALSE,
  ##   x = lapply(
  ##     ops,
  ##     function(x) {
  ##       mapply(
  ##         maker,
  ##         argTypes = argTypes,
  ##         MoreArgs = list(op = x, seed = seed,
  ##                         outer_code = outer_code, inner_codes = inner_codes),
  ##         SIMPLIFY = FALSE
  ##       )
  ##     })
  ## )
  names(opTests) <- sapply(opTests, `[[`, 'name')
  invisible(opTests)
}

## Takes an element of distn_params list and returns a list of AD test
## parameterizations, each of which test_AD can use.
##
## distn_param: Probability distribution parameterization, which
##              must have the following fields/subfields:
##              - name
##              - variants
##              - args
##                - rand_variate
##                  - type
##              Additional args must also have type field.
## more_args:   Passed to make_op_param().
##
make_distribution_fun_AD_test <- function(distn_param, maker = make_AD_test) {
  distn_name <- distn_param$name
  ops <- sapply(distn_param$variants, paste0, distn_name, simplify = FALSE)

  rand_variate_idx <- which(names(distn_param$args) == 'rand_variate')

  argTypes <- sapply(distn_param$variants, function(variant) {
    op <- paste0(variant, distn_name)
    rand_variate_type <- distn_param$args$rand_variate$type
    first_argType <- switch(
      variant,
      d = rand_variate_type,
      p = rand_variate_type,
      q = c('double(0)')##, 'double(1, 4)')
    )
    first_arg_name <- switch(variant, d = 'x', p = 'q', q = 'p')
    ## need this complicated expand.grid call here because the argTypes might be
    ## character vectors (e.g. if rand_variate_type is c("double(0)", "double(1, 4)")
    ## then we create a test where we sample a scalar from the support and
    ## another where we sample a vector of length 4 from the support
    grid <- eval(as.call(c(
      expand.grid, list(first_argType),
      lapply(distn_param$args, `[[`, 'type')[-rand_variate_idx]
    )))
    argTypes <- as.list(data.frame(t(grid), stringsAsFactors=FALSE))
    lapply(argTypes, function(v) {
      names(v) <- c(
        first_arg_name,
        names(distn_param$args[-rand_variate_idx])
      )
      v
    })
  }, simplify = FALSE)
  for(i in seq_along(distn_param$args)) {
    if(is.null(distn_param$args[[i]]$size))
      distn_param$args[[i]]$size <- lapply(distn_param$args[[i]]$type,
                                           function(type) {
                                             if(is.character(type))
                                               type <- parse(text = type, keep.source=FALSE)[[1]]
                                             add_missing_size(nimble:::argType2symbol(type))$size
                                           })
  }
  argSizes <- sapply(distn_param$variants, function(variant) {
    rand_variate_size <- distn_param$args$rand_variate$size
    first_argSize <- switch(
      variant,
      d = rand_variate_size,
      p = rand_variate_size,
      q = 1
    )
    first_arg_name <- switch(variant, d = 'x', p = 'q', q = 'p')
    grid <- eval(as.call(c(
      expand.grid, list(first_argSize),
      lapply(distn_param$args, `[[`, 'size')[-rand_variate_idx]
    )))
    argSizes <- as.list(data.frame(t(grid), stringsAsFactors=FALSE))
    lapply(argSizes, function(v) {
      names(v) <- c(
        first_arg_name,
        names(distn_param$args[-rand_variate_idx])
      )
      v
    })
  }, simplify = FALSE)
  input_gen_funs <- lapply(distn_param$args[-rand_variate_idx], `[[`, 'input_gen_fun')
  input_gen_funs_list <- sapply(distn_param$variants, function(variant) {
    arg1_input_gen_fun <- switch(
      variant,
      d = list(x = distn_param$args$rand_variate$input_gen_fun),
      p = list(q = distn_param$args$rand_variate$input_gen_fun),
      q = list(p = runif)
    )
    c(arg1_input_gen_fun, input_gen_funs)
  }, simplify = FALSE)

  op_params <- unlist(lapply(distn_param$variants,
                             function(variant) {
                               ans <- list()
                               for(i in seq_along(argTypes[[variant]])) {
                                 these_argTypes <- argTypes[[variant]][[i]]
                                 these_argSizes  <- argSizes[[variant]][[i]]
                                 wrt_args = intersect(
                                   distn_param$wrt, names(these_argTypes)
                                 )
                                 new_ans <-  maker(
                                   ops[[variant]], these_argTypes,
                                   wrt_args = wrt_args,
                                   input_gen_funs = input_gen_funs_list[[variant]],
                                   more_args = distn_param$more_args[[variant]],
                                   size = these_argSizes
                                 )
                                 ans[[length(ans)+1]]<-new_ans
                               }
                               ans
                             }), recursive = FALSE)

  ## op_params <- unlist(
  ##   lapply(
  ##     distn_param$variants, function(variant) {
  ##       lapply(
  ##         argTypes[[variant]],
  ##         function(these_argTypes) {
  ##           wrt_args = intersect(
  ##             distn_param$wrt, names(these_argTypes)
  ##           )
  ##           maker(
  ##             ops[[variant]], these_argTypes,
  ##             wrt_args = wrt_args,
  ##             input_gen_funs = input_gen_funs_list[[variant]],
  ##             more_args = distn_param$more_args[[variant]]
  ##           )
  ##         }
  ##       )
  ##     }
  ##   ), recursive = FALSE
  ## )
  names(op_params) <- sapply(op_params, `[[`, 'name')
  return(op_params)
}

#############################
## input generation functions
#############################

## arg_size comes from arg$size where arg is a symbolBasic object
gen_pos_def_matrix <- function(arg_size) {
  m <- arg_size[1] ## assumes matrix argType is square
  if(length(arg_size) == 1) # single length defined as arg_size = m^2
    m <- sqrt(m)
  mat <- diag(m)
  mat[lower.tri(mat, diag = TRUE)] <- runif(m*(m + 1)/2)
  mat %*% t(mat)
}

Try the nimbleEcology package in your browser

Any scripts or data that you put into this service are public.

nimbleEcology documentation built on June 27, 2024, 5:09 p.m.