## R/nimbleFunction_Rderivs.R
##
## Defines functions convertWrtArgToIndices nimDerivs_nf nimDerivs_nf_numericwrt nimDerivs_nf_charwrt nimDerivs_calcNodes nimDerivs_model calcDerivs_internal calcDerivs_hessian nimDerivs derivInfo
##
## Documented in nimDerivs

## Identity placeholder: hands back its first argument untouched and
## ignores anything supplied through `...`.
derivInfo <- function(x, ...) {
    x
}

## dummy for recursing in sizeNimDerivs
## A nimbleFunction whose `run` method has an empty body; only its declared
## argument list (call, order, dropArgs, wrt) carries information.
## NOTE(review): presumably these declared argument types are what size
## processing of nimDerivs() calls consults -- confirm against sizeNimDerivs.
nimDerivs_dummy <- nimbleFunction(
    run = function(call = void(),
                   order = integer(1),
                   dropArgs = void(),
                   wrt = integer(1)) {
    }
)

#' Nimble Derivatives
#' 
#' Computes the value, 1st order (Jacobian), and 2nd order (Hessian) derivatives of a given
#' \code{nimbleFunction} method and/or model log probabilities
#' 
#' @param call a call to a \code{nimbleFunction} method with arguments
#' included.  Can also be a call to  \code{model$calculate(nodes)}, or to 
#' \code{calculate(model, nodes)}.
#' @param wrt a character vector of either: names of function arguments 
#' (if taking derivatives of a \code{nimbleFunction} method), or node names 
#' (if taking derivatives of \code{model$calculate(nodes)}) to take derivatives 
#' with respect to.  When differentiating a \code{nimbleFunction} method,
#' \code{wrt} may also be a numeric vector of indices; if it is left empty in
#' that case, derivatives will be taken with respect to all arguments of the
#' method given in \code{call}.
#' @param order an integer vector with values within the set {0, 1, 2}, 
#' corresponding to whether the function value, Jacobian, and Hessian should be
#'  returned respectively.  Defaults to \code{c(0, 1, 2)}.
#' @param model (optional) for derivatives of a nimbleFunction that involves model
#' calculations, the uncompiled model that is used. This is needed in order
#' to be able to correctly restore values into the model when \code{order} does not
#' include 0 (or in all cases when double-taping).
#' @param ... additional arguments intended for internal use only.
#'
#' @details Derivatives for uncompiled nimbleFunctions are calculated
#' numerically using the \code{numDeriv} package (Jacobians) and the
#' \code{pracma} package (Hessians).  If a required package is not installed,
#' an error will be issued.  Derivatives for matrix valued arguments will be
#' returned in column-major order.
#' 
#' @return a \code{nimbleList} with elements \code{value}, \code{jacobian},
#' and \code{hessian}.
#'
#' @aliases derivs AD
#' 
#' @examples 
#' 
#' \dontrun{
#' model <- nimbleModel(code = ...)
#' calcDerivs <- nimDerivs(model$calculate(model$getDependencies('x')),
#'  wrt = 'x')
#' }
#' 
#' @export
nimDerivs <- function(call = NA,
                      wrt = NULL,
                      order = nimC(0,1,2),
                      model = NA,
                      ...){ ## ... absorbs compile-only params
  ## The caller's frame: the unevaluated `call` is later evaluated there.
  fxnEnv <- parent.frame()
  ## Capture the call unevaluated so that the expression given as `call`
  ## can be inspected and dispatched on rather than run immediately.
  fxnCall <- match.call()
  if(is.null(fxnCall[['order']])) fxnCall[['order']] <- order
  derivFxnCall <- fxnCall[['call']]
  ## Below are two checks to see if the `call` argument is a model calculate 
  ## call: either calculate(model, nodes) or model$calculate(nodes).
  ## If so, the work is handed to nimDerivs_model(); otherwise it goes to
  ## nimDerivs_nf().
  model_calculate_case <- FALSE
  if(deparse(derivFxnCall[[1]]) == 'calculate'){
      model_calculate_case <- TRUE
  }
  if(length(derivFxnCall[[1]]) == 3 && deparse(derivFxnCall[[1]][[1]]) == '$'){
    if(deparse(derivFxnCall[[1]][[3]]) == 'calculate'){
      ## `obj$calculate(...)` only counts as a model calculate call if a
      ## 'modelDef' can be found via the object on the left of `$`.
      modelError <- tryCatch(get('modelDef', pos = eval(derivFxnCall[[1]][[2]],
                                                        envir = fxnEnv)),
                             error = function(e) e)
      if(!inherits(modelError, 'error')){
          model_calculate_case <- TRUE
      }
    }
  }
  ## Handle case of the `call` argument being a calcNodes call for faster
  ## AD testing of model calculate calls. 
  ## The detection code is commented out below, so this flag is always FALSE
  ## and the corresponding branch further down cannot currently be reached.
  model_calcNodes_case <- FALSE
  ## 6/20/22 I am checking if this use case is fully deprecated by commenting this out
  ## if(length(derivFxnCall[[1]]) == 3 && deparse(derivFxnCall[[1]][[1]]) == '$'){
  ##     if(deparse(derivFxnCall[[1]][[3]]) == 'run'){
  ##         nf <- eval(derivFxnCall[[1]][[2]], envir = fxnEnv)
  ##         if(is(nf, "calcNodes_refClass")) {
  ##             model <- eval(nf$Robject$model, envir = fxnEnv)
  ##             if(!(exists('CobjectInterface', model) && !is(model$CobjectInterface, 'uninitializedField'))) { ## but, model should always be compiled because existence of calcNodes_refClass implies that.
  ##                 cModel <- compileNimble(model)
  ##             } else cModel <- eval(quote(model$CobjectInterface))
  ##             model_calcNodes_case <- TRUE
  ##         }
  ##     }
  ## }

  ## if(!is.na(dropArgs)){
  ##   removeArgs <- which(wrt == dropArgs)
  ##   if(length(removeArgs) > 0)
  ##     wrt <- wrt[-removeArgs]
  ## }
  if(model_calculate_case) {
      ## Note that the `model` argument is not forwarded here; nimDerivs_model
      ## obtains the model from the call expression itself.
      ans <- nimDerivs_model( order = order, wrt = wrt, derivFxnCall = derivFxnCall, fxnEnv = fxnEnv )
  } else {
      if(model_calcNodes_case) {
          stop("nimDerivs: invalid case.")
          ## ans <- nimDerivs_calcNodes( order = order, wrt = wrt, derivFxnCall = derivFxnCall, fxnEnv = fxnEnv, model = cModel, nodes = calcNodes)
      } else ans <- nimDerivs_nf( order = order, wrt = wrt, derivFxnCall = derivFxnCall, fxnEnv = fxnEnv, model = model )
  }
  ans
}

## Numerically compute one Hessian per output element of `func` at `X`.
##
## Arguments:
##   func  function of a single numeric vector.  Its result is flattened with
##         as.numeric() and each element is differentiated separately.
##   X     numeric vector at which the Hessians are evaluated.
##   n     number of output elements of `func`.  If missing, it is determined
##         by calling func(X) once and taking the length of the result.
##
## Returns a p x p x n array (p = length(X)) whose [, , i] slice is the
## Hessian of the i-th output element, as computed by pracma::hessian().
## A function with zero output elements yields an empty p x p x 0 array.
calcDerivs_hessian <- function(func, X, n) { # deriv_package = "pracma"
    if(missing(n)) {
        n <- length(func(X))
    }
    p <- length(X)
    outHess <- array(NA, dim = c(p, p, n))
    ## seq_len(n) rather than 1:n: with n == 0, 1:n would iterate over
    ## c(1, 0) and try to fill slices that do not exist.
    for(i in seq_len(n)) {
        ## Scalar-valued view of func that picks out output element i.
        wrapped_func <- function(X) as.numeric(func(X))[i]
        ## We can't import both pracma::hessian and numDeriv::hessian,
        ## as R CMD INSTALL gives
        ## Warning: replacing previous import ‘pracma::hessian’ by ‘numDeriv::hessian’ when loading ‘nimble’
        outHess[, , i] <- pracma::hessian(wrapped_func, X)
        ## if(deriv_package == "pracma")
        ##     outHess[, , i] <- pracma::hessian(wrapped_func, X)
        ## else
        ##     outHess[, , i] <- hessian(wrapped_func, X)
    }
    outHess
}

## Numerically compute the value, Jacobian and/or Hessian of `func` at `X`
## and package them in an ADNimbleList.
##
## Arguments:
##   func           function of a single numeric vector.
##   X              numeric vector at which to evaluate / differentiate.
##   order          vector of requested orders; membership of 0, 1 and 2
##                  selects value, Jacobian and Hessian respectively.
##   resultIndices  optional index vector.  If supplied, the Jacobian columns
##                  and the first two Hessian dimensions are subset by it
##                  before being returned (this allows repeated or reordered
##                  wrt elements in the result).
##
## Returns an ADNimbleList with `value`, `jacobian` and `hessian` set only
## for the orders that were requested.
##
## Jacobians come from numDeriv::jacobian(); Hessians come from
## pracma::hessian() (directly for scalar-valued func, via
## calcDerivs_hessian() otherwise).  func is called once more at X after all
## derivatives are computed (see the comment above that call).
calcDerivs_internal <- function(func, X, order, resultIndices) { # deriv_package = "pracma"
    deriv_package <- "pracma"  # fixed given can't import both pracma and numDeriv hessian()
    hessianFlag <- 2 %in% order
    jacobianFlag <- 1 %in% order
    ## When called for a model$calculate, valueFlag will always be 0 here
    ## because value will be obtained later (if requested) after restoring
    ## model variables.
    valueFlag <- 0 %in% order
    outVal <- NULL

    if(hessianFlag) {
        ## determine output size
        returnType_scalar <- FALSE
        ## The environment in which func was defined, i.e. the frame of the
        ## function that built the closure.
        funcEnv <- environment(func)
        ## This branch is taken only when that environment holds a `fxnCall`
        ## whose function is nimDerivs_nf; the output size is then read from
        ## the declared return type instead of by running func.
        ## NOTE(review): none of the closures built in this file appear to
        ## define such a `fxnCall` -- confirm whether this branch is still
        ## reachable.
        if (deparse(funcEnv$fxnCall[[1]]) == 'nimDerivs_nf') {
            fxnEnv <- funcEnv[['fxnEnv']]
            derivFxnCall <- funcEnv[['derivFxnCall']]
            genFunEnv <- environment(
                environment(
                    eval(derivFxnCall[[1]], envir = fxnEnv)
                )[['.generatorFunction']]
            )
            returnType <- argType2symbol(
                genFunEnv$methodList[[funcEnv$derivFxnName]]$returnType
            )
            returnType_scalar <- all(returnType$size == 1)
        } else {
            ## Otherwise run func once and use the length of its result.
            ## outVal is kept so the value need not be recomputed below.
            outVal <- func(X)
            returnType_scalar <- length(outVal) == 1
        }
        
        if (returnType_scalar) {
            ## We can't import both pracma::hessian and numDeriv::hessian,
            ## as R CMD INSTALL gives
            ## Warning: replacing previous import ‘pracma::hessian’ by ‘numDeriv::hessian’ when loading ‘nimble’
            hess <- pracma::hessian(func, X)
            ## numDeriv's hessian() only works for scalar-valued functions
            ## if(deriv_package == "pracma") {
            ##     hess <- pracma::hessian(func, X)
            ## } else {
            ##     hess <- hessian(func, X)
            ## }
            ## Store the single Hessian as a p x p x 1 array so that scalar
            ## and non-scalar cases return the same shape of object.
            outHess <- array(NA, dim = c(dim(hess)[1], dim(hess)[2], 1))
            outHess[,,1]  <- hess
            if(jacobianFlag) outGrad <- numDeriv::jacobian(func, X)
            if(valueFlag && is.null(outVal)) outVal <- func(X)
        } else {
            ## If hessians are requested for a non-scalar valued function,
            ## derivatives taken using pracma or numDeriv's genD() function.  After that,
            ## we extract the various derivative elements and arrange them properly.
            if(deriv_package == "pracma") {
                value <- func(X)
                if(valueFlag && is.null(outVal)) outVal <- value
                if(jacobianFlag) outGrad <- numDeriv::jacobian(func, X)
                outHess <- calcDerivs_hessian(func, X, n = length(value)) 
            } else {
                ## Not reached while deriv_package is fixed to "pracma" above.
                derivList <- numDeriv::genD(func, X)
                if(valueFlag && is.null(outVal)) outVal <- derivList$f0
                if(jacobianFlag) outGrad <- derivList$D[,1:derivList$p, drop = FALSE]
                outHessVals <- derivList$D[,(derivList$p + 1):dim(derivList$D)[2],
                                           drop = FALSE]
                outHess <- array(NA, dim = c(derivList$p, derivList$p, length(derivList$f0)))
                singleDimMat <- matrix(NA, nrow = derivList$p, ncol = derivList$p)
                singleDimMatUpperTriDiag <- upper.tri(singleDimMat, diag = TRUE)
                singleDimMatLowerTriDiag <- lower.tri(singleDimMat)
                for(outDim in seq_along(derivList$f0)){
                    ## Fill the upper triangle (with diagonal) from the
                    ## second-derivative columns, then mirror it into the
                    ## lower triangle.
                    singleDimMat[singleDimMatUpperTriDiag] <- outHessVals[outDim,]
                    singleDimMat[singleDimMatLowerTriDiag] <- t(singleDimMat)[singleDimMatLowerTriDiag]
                    outHess[,,outDim] <- singleDimMat
                }
            }
        }
    } else if(jacobianFlag){
        ## If jacobians are requested, derivatives taken using numDeriv's jacobian()
        ## function.
        if(valueFlag) outVal <- func(X)
        outGrad <- numDeriv::jacobian(func, X)
    } else if(valueFlag)
        outVal <- func(X)

    ## Run func one last time to make sure that if model calculations are in the method whose
    ## derivative is taken that the values in the model are from the input and not values
    ## from the numerical-derivative calculations.
    func(X)

    
    outList <- ADNimbleList$new()
    if(!missing(resultIndices)) {
        ## Derivatives were taken with respect to unique elements; expand or
        ## reorder them to match what the caller asked for.
        if(jacobianFlag) outGrad <- outGrad[, resultIndices, drop=FALSE]
        if(hessianFlag) outHess <- outHess[resultIndices, resultIndices, , drop=FALSE]
    }
    if(valueFlag) outList$value <- outVal
    if(jacobianFlag) outList$jacobian <- outGrad
    if(hessianFlag) outList$hessian <- outHess
    return(outList)
}

## Numerical derivatives of a model calculate call with respect to model
## node elements.
##
## Arguments:
##   nimFxn        the calculate call, used (unevaluated) only when
##                 derivFxnCall is NULL.
##   order         requested orders (0 = value, 1 = Jacobian, 2 = Hessian).
##   wrt           character vector of model node names to differentiate
##                 with respect to; may contain blocks and repeats.
##   derivFxnCall  the unevaluated calculate call, in either
##                 calculate(model, nodes) or model$calculate(nodes) form.
##   fxnEnv        environment in which the model and nodes expressions of
##                 that call are evaluated.
##
## Returns the object built by calcDerivs_internal() for orders 1 and/or 2,
## with `value` added afterwards if order includes 0.  If order is only 0,
## the result starts from NULL, so `ans$value <- ...` produces a plain list
## holding just `value`.
##
## The local names `x`, `model` and `currentX` are referred to by name in
## code generated below, and `fxnCall` is read by calcDerivs_internal()
## through environment(func); they must keep these names.
nimDerivs_model <- function( nimFxn, order = c(0, 1, 2), wrt = NULL, derivFxnCall = NULL, fxnEnv = parent.frame()) {
    fxnCall <- match.call()
    ## This may be called directly (for testing) or from nimDerivs (typically).
    ## In the former case, we get derivFxnCall from the nimFxn argument.
    ## In the latter case, it is already match.call()ed and is passed here via derivFxnCall
    if(is.null(derivFxnCall))
        derivFxnCall <- fxnCall[['nimFxn']] ## either calculate(model, nodes) or model$calculate(nodes)
    else
        if(is.null(fxnCall[['order']]))
            fxnCall[['order']] <- order

    valueFlag <- 0 %in% order
    
    if(deparse(derivFxnCall[[1]]) == 'calculate'){
        ## calculate(model, nodes) format
        derivFxnCall <- match.call(calculate, derivFxnCall)
        model <- eval(derivFxnCall[['model']], envir = fxnEnv)
        nodes <- eval(derivFxnCall[['nodes']], envir = fxnEnv)
        if(is.null(nodes))
            nodes <- model$getMaps('nodeNamesLHSall')
    } else {
        ## model$calculate(nodes) format, where nodes could be missing
        model <-  eval(derivFxnCall[[1]][[2]], envir = fxnEnv)
        if(length(derivFxnCall) < 2){ ## nodes is missing: default to all nodes
          nodes <- model$getMaps('nodeNamesLHSall')
        } else {
          nodes <- eval(derivFxnCall[[2]], envir = fxnEnv)
        }
    }
    ## We rely on fact that model$expandNodeNames will return column-major order.
    ## That makes it suitable for the canonical ordering of any blocks in wrt.

    ## We need to expandNodeNames (among other reasons) to determine
    ## any duplicates in wrt, such as "x" and "x[2]".
    
    ## It might also be possible to make use of values() and values()<- here.
    ## However, we instead are using a single model utility function (expandNodeNames)
    ## and basing every step on the result of that.  The hope is this will
    ## minimize risk of errors in case values() would somehow handle wrt in
    ## a different order than expandNodeNames.  (It shouldn't, but values() and values()<-
    ## have implementations that are hard to follow.)

    ## Variable names (indexing stripped) appearing in wrt must all be model
    ## variables.
    wrt_unique_names <- unique(.Call(parseVar, wrt))
    if(!all(wrt_unique_names %in% model$getVarNames())){
        stop('Error:  the wrt argument to nimDerivs() contains variable names that are not
         in the model.')
    }
    
    ## scalar node elements, in order for results
    wrt_expanded <- model$expandNodeNames(wrt, returnScalarComponents = TRUE, unique = FALSE)
    ## unique scalar node elements
    wrt_expanded_unique <- unique(wrt_expanded)
    ## indices of original node elements in unique elements, for construction of results after getting derivatives wrt unique elements
    wrt_orig_indices <- match(wrt_expanded, wrt_expanded_unique)
    
    ## List of lines like "model$beta[2] <- x[i]", where "beta[2]" is an expanded wrt node and x[i] is a variable inside func.
    wrt_expanded_unique_assign_code <- mapply(
        function(node, I)
            parse(text = paste0("model$", node, " <- x[", I, "]"), keep.source = FALSE)[[1]],
        wrt_expanded_unique,
        seq_along(wrt_expanded_unique))
    ## We could consider doing the above step by substitute, but it would take some care to handle nodes with indices.
    
    length_wrt <- length(wrt_expanded_unique_assign_code)
    
    ## func could be made more efficient if we re-block wrt nodes that are in the same variable.
    ## makeVertexNamesFromIndexArray2 might be helpful for that purpose.
    ## func writes the elements of x into the model's wrt nodes and returns
    ## the result of model$calculate(nodes).
    func <- function(x) {
        do.call("{", wrt_expanded_unique_assign_code)
        ## equivalent to:
        ##        for(i in 1:length_wrt) {
        ##            eval(wrt_expanded_unique_assign_code[[i]])
        ##        }
        model$calculate(nodes)
    }
    
    ## make current values in single vector.
    ## This list will have code like `currentX[3] <- model$beta[2]`
    wrt_expanded_unique_get_code <- mapply(
        function(node, I)
            parse(text = paste0("currentX[",I,"] <- model$", node), keep.source = FALSE)[[1]],
        wrt_expanded_unique,
        seq_along(wrt_expanded_unique))
    
    currentX <- numeric(length_wrt)
    ## Runs the generated assignments in this frame, filling currentX.
    do.call("{", wrt_expanded_unique_get_code)
    
    ## Determine what nodes to restore.  We will do this by variable because in many cases
    ## it will be faster to copy and restore a whole variable in one line than to do elements
    ## with one line for each element.
    ##
    ## If value is not being returned, we restore wrt vars, nodes vars, and any logProb vars for nodes
    ## If value is being returned, we recalculate at the end, using func.  This restores any wrt nodes that
    ## may have been perturbed during numerical derivatives as well as updating all nodes
    if(!valueFlag) {
        wrtVars <- unique(.Call(parseVar, wrt))
        calcNodeVars <- unique(.Call(parseVar, nodes))
                                        # determine any stochastic nodes and include their logProb nodes
        anyStoch_calcNodeVars <- unlist(lapply(calcNodeVars, function(x) model$getVarInfo(x)$anyStoch))
        logProbVars <- if(sum(anyStoch_calcNodeVars)) makeLogProbName( calcNodeVars[anyStoch_calcNodeVars] ) else character()
        varsToRestore <- c(union(wrtVars, calcNodeVars), logProbVars)
        savedVars <- new.env()
    } else {
        ## savedVars is intentionally not created here; the loops below do
        ## nothing when varsToRestore is empty.
        varsToRestore <- character()
    }
    for(i in varsToRestore)
        savedVars[[i]] <- model[[i]]
    
    ## outGrad <- jacobian(func, currentX)
    ## We do not want calcDerivs_internal to calculate the value,
    ## because we want to restore variables first.
    if(valueFlag) order <- order[ order != 0 ]
    if(length(order) > 0) {
        ans <- calcDerivs_internal(func, currentX, order, wrt_orig_indices)
    } else ans <- NULL
    for(i in varsToRestore)
        model[[i]] <- savedVars[[i]]
    
    if(valueFlag) {
        ## Recalculating at the original wrt values both yields the value and
        ## leaves the model consistent with those values.
        ans$value <- func(currentX)
    }
    ans
}

## Numerical derivatives of an arbitrary call (intended: calcNodes$run())
## with respect to model node elements.
##
## Arguments:
##   nimFxn        the call to differentiate, used (unevaluated) only when
##                 derivFxnCall is NULL.
##   order         requested orders (0 = value, 1 = Jacobian, 2 = Hessian).
##   wrt           character vector of model node names.
##   derivFxnCall  the unevaluated call; it is substituted verbatim into the
##                 body of func below.
##   fxnEnv        accepted for signature parity; not used in this body.
##   model         the model whose wrt nodes are set and restored.
##   nodes         node names used only to decide which variables to save
##                 and restore when order does not include 0.
##
## Returns the same structure as nimDerivs_model().
##
## NOTE(review): the only call to this function visible in this file is
## commented out inside nimDerivs(), so it appears to be unused at present.
nimDerivs_calcNodes <- function( nimFxn, order = c(0, 1, 2), wrt = NULL, derivFxnCall = NULL, fxnEnv = parent.frame(), model, nodes) {
    ## Handles nimDerivs(calcNodes$run()) case, used in AD testing. This function could be merged
    ## into nimDerivs_model as much of the code is duplicated but given nimDerivs for calculate
    ## is user facing and this is not, it's standalone code for now.
    
    fxnCall <- match.call()
    ## This may be called directly (for testing) or from nimDerivs.
    ## In the former case, we get derivFxnCall from the nimFxn argument.
    ## In the latter case, it is already match.call()ed and is passed here via derivFxnCall
    if(is.null(derivFxnCall))
        derivFxnCall <- fxnCall[['nimFxn']] ## the unevaluated call to differentiate
    else
        if(is.null(fxnCall[['order']]))
            fxnCall[['order']] <- order

    valueFlag <- 0 %in% order
    
    ## We rely on fact that model$expandNodeNames will return column-major order.
    ## That makes it suitable for the canonical ordering of any blocks in wrt.

    ## We need to expandNodeNames (among other reasons) to determine
    ## any duplicates in wrt, such as "x" and "x[2]".
    
    ## It might also be possible to make use of values() and values()<- here.
    ## However, we instead are using a single model utility function (expandNodeNames)
    ## and basing every step on the result of that.  The hope is this will
    ## minimize risk of errors in case values() would somehow handle wrt in
    ## a different order than expandNodeNames.  (It shouldn't, but values() and values()<-
    ## have implementations that are hard to follow.)

    wrt_unique_names <- unique(.Call(parseVar, wrt))
    if(!all(wrt_unique_names %in% model$getVarNames())){
        stop('Error:  the wrt argument to nimDerivs() contains variable names that are not
         in the model.')
    }
    
    ## scalar node elements, in order for results
    wrt_expanded <- model$expandNodeNames(wrt, returnScalarComponents = TRUE, unique = FALSE)
    ## unique scalar node elements
    wrt_expanded_unique <- unique(wrt_expanded)
    ## indices of original node elements in unique elements, for construction of results after getting derivatives wrt unique elements
    wrt_orig_indices <- match(wrt_expanded, wrt_expanded_unique)
    
    ## List of lines like "model$beta[2] <- x[i]", where "beta[2]" is an expanded wrt node and x[i] is a variable inside func.
    wrt_expanded_unique_assign_code <- mapply(
        function(node, I)
            parse(text = paste0("model$", node, " <- x[", I, "]"), keep.source = FALSE)[[1]],
        wrt_expanded_unique,
        seq_along(wrt_expanded_unique))
    ## We could consider doing the above step by substitute, but it would take some care to handle nodes with indices.
    
    length_wrt <- length(wrt_expanded_unique_assign_code)
    
    ## func could be made more efficient if we re-block wrt nodes that are in the same variable.
    ## makeVertexNamesFromIndexArray2 might be helpful for that purpose.
    ## Unlike nimDerivs_model, the body of func ends with the user's call
    ## itself (CALCCALL is replaced by derivFxnCall), so names in that call
    ## are resolved from this function's frame when func runs.
    func <- eval(substitute(function(x) {
        do.call("{", wrt_expanded_unique_assign_code)
        ## equivalent to:
        ##        for(i in 1:length_wrt) {
        ##            eval(wrt_expanded_unique_assign_code[[i]])
        ##        }
        CALCCALL
    }, list(CALCCALL = derivFxnCall)))
    
    ## make current values in single vector.
    ## This list will have code like `currentX[3] <- model$beta[2]`
    wrt_expanded_unique_get_code <- mapply(
        function(node, I)
            parse(text = paste0("currentX[",I,"] <- model$", node), keep.source = FALSE)[[1]],
        wrt_expanded_unique,
        seq_along(wrt_expanded_unique))
    
    currentX <- numeric(length_wrt)
    ## Runs the generated assignments in this frame, filling currentX.
    do.call("{", wrt_expanded_unique_get_code)
    
    ## Determine what nodes to restore.  We will do this by variable because in many cases
    ## it will be faster to copy and restore a whole variable in one line than to do elements
    ## with one line for each element.
    ##
    ## If value is not being returned, we restore wrt vars, nodes vars, and any logProb vars for nodes
    ## If value is being returned, we recalculate at the end, using func.  This restores any wrt nodes that
    ## may have been perturbed during numerical derivatives as well as updating all nodes
    if(!valueFlag) {
        wrtVars <- unique(.Call(parseVar, wrt))
        calcNodeVars <- unique(.Call(parseVar, nodes))
                                        # determine any stochastic nodes and include their logProb nodes
        anyStoch_calcNodeVars <- unlist(lapply(calcNodeVars, function(x) model$getVarInfo(x)$anyStoch))
        logProbVars <- if(sum(anyStoch_calcNodeVars)) makeLogProbName( calcNodeVars[anyStoch_calcNodeVars] ) else character()
        varsToRestore <- c(union(wrtVars, calcNodeVars), logProbVars)
        savedVars <- new.env()
    } else {
        varsToRestore <- character()
    }
    for(i in varsToRestore)
        savedVars[[i]] <- model[[i]]
    
    ## outGrad <- jacobian(func, currentX)
    ## We do not want calcDerivs_internal to calculate the value,
    ## because we want to restore variables first.
    if(valueFlag) order <- order[ order != 0 ]
    if(length(order) > 0) {
        ans <- calcDerivs_internal(func, currentX, order, wrt_orig_indices)
    } else ans <- NULL
    for(i in varsToRestore)
        model[[i]] <- savedVars[[i]]
    
    if(valueFlag) {
        ans$value <- func(currentX)
    }
    ans
}

## Numerical derivatives of a function call with respect to arguments named
## by character strings.
##
## Arguments:
##   derivFxn      the function being differentiated (used for its formals).
##   derivFxnCall  the call to differentiate.  Its arguments are evaluated in
##                 fxnEnv and are taken to be in the order of the formals.
##   order         requested orders, passed on to calcDerivs_internal().
##   wrt           character vector such as c("x", "y[2]", "z[1:2, 3]"):
##                 argument names, optionally indexed.  Repeats and overlaps
##                 (e.g. "x" and "x[2]") are allowed.
##   fxnEnv        environment in which the call's arguments are evaluated
##                 and in which the function is looked up by name.
##
## Returns the result of calcDerivs_internal().  Derivatives are taken with
## respect to the unique scalar elements touched by wrt, and then expanded to
## the order (and repeats) in which wrt requested them.
##
## The local names `fxnArgs`, `currentX` and the func argument `x` are
## referred to by name in code generated below and must keep these names.
nimDerivs_nf_charwrt <- function(derivFxn,
                                 derivFxnCall,
                                 order,
                                 wrt,
                                 fxnEnv) {
  fA <- formalArgs(derivFxn)

  ## convert 'x[2]' to quote(x[2]).
  wrt_code <- lapply(wrt,
                     function(x) parse(text = x, keep.source = FALSE)[[1]])

  ## convert quote(x[2]) to quote(x), then to the string "x"
  wrt_names <- lapply(wrt_code,
                      function(x) if(is.name(x)) x else x[[2]])
  wrt_name_strings <- as.character(wrt_names)

  ## Unique argument names, in order of first appearance in wrt
  wrt_unique_name_strings <- unique(wrt_name_strings)

  if(!all(wrt_unique_name_strings %in% fA)){
    stop('Error:  the wrt argument to nimDerivs() contains names that are not
         arguments to the nimFxn argument.')
  }

  ## get the user-supplied arguments to the nFunction
  derivFxnCall_args <- as.list(derivFxnCall)[-1]

  ## Get the value of the user-supplied args
  ## These will be modified in func.
  ## This is ordered by fA, the arguments of the target function
  fxnArgs <- lapply(derivFxnCall_args, function(x) eval(x, envir = fxnEnv))
  names(fxnArgs) <- as.character(fA)

  ## Scratch environment holding, for each wrt argument, an array of the same
  ## shape whose entries are their own flat (column-major) positions.
  ## Evaluating quote(x[2, 1]) there yields the flat indices it refers to.
  eval_env <- new.env()

  current_x_index <- 1
  fxnArgs_assign_code <- list()
  get_init_values_code <- list()
  result_x_indices_by_wrt <- list()
  for(i in seq_along(wrt_unique_name_strings)) {
      ## Below, x represents the input argument to func
      this_unique_wrt_string <- wrt_unique_name_strings[i]
      dims <- dimOrLength(fxnArgs[[this_unique_wrt_string]])
      ## seq_len() rather than 1:prod(dims) so a zero-size argument gives an
      ## empty index array.
      assign(this_unique_wrt_string,
             array(seq_len(prod(dims)), dim = dims),
             envir = eval_env)

      ## which wrt entries use this argument
      i_wrt_orig <- which(this_unique_wrt_string == wrt_name_strings)

      ## flat indices requested by each of those wrt entries
      flat_indices <- lapply(wrt_code[i_wrt_orig], ## quote(x[2]) and so on for any wrt using x
                             function(wrt_code_) {
                                 eval(wrt_code_, envir = eval_env)
                             })
      unique_flat_indices <- unique(unlist(flat_indices))

      ## Positions in the concatenated x vector given to func.
      ## seq_along() rather than 1:length(): when an argument contributes no
      ## elements, 1:0 would produce two bogus positions and make the
      ## mapply() calls below fail on mismatched lengths.
      x_indices <- current_x_index - 1 + seq_along(unique_flat_indices)

      ## Lines like fxnArgs[["x"]][J] <<- x[K].  `<<-` is used because these
      ## lines run inside func and update fxnArgs in this enclosing frame.
      fxnArgs_assign_code[[i]] <- mapply(
          function(jval, kval)
              substitute(fxnArgs[[ I ]][J] <<- x[K], list(I = this_unique_wrt_string,
                                                         J = jval,
                                                         K = kval)),
          unique_flat_indices,
          x_indices,
          SIMPLIFY = FALSE)
      ## Lines like currentX[K] <- fxnArgs[["x"]][J], run once below to
      ## collect the starting point.
      get_init_values_code[[i]] <- mapply(
          function(jval, kval)
              substitute(currentX[K] <- fxnArgs[[ I ]][J], list(I = this_unique_wrt_string,
                                                             J = jval,
                                                             K = kval)),
          unique_flat_indices,
          x_indices,
          SIMPLIFY = FALSE)

      ## For each original wrt entry, where its elements live in x.
      result_x_indices <- lapply(flat_indices,
                                 function(fi) x_indices[match(fi, unique_flat_indices)])
      result_x_indices_by_wrt[i_wrt_orig] <- result_x_indices

      current_x_index <- current_x_index + length(unique_flat_indices)
  }
  fxnArgs_assign_code <- unlist(fxnArgs_assign_code, recursive = FALSE)
  get_init_values_code <- unlist(get_init_values_code, recursive = FALSE)
  result_x_indices_all <- unlist(result_x_indices_by_wrt)

  length_x <- current_x_index - 1
  currentX <- numeric(length_x)
  ## Runs the generated assignments in this frame, filling currentX.
  do.call("{", get_init_values_code)
  ## equivalent to:
  ##  for(i in 1:length_x) {
  ##      eval(get_init_values_code[[i]])
  ##  }

  derivFxnName <- as.character(derivFxnCall[[1]])
  ## func copies x into the wrt elements of fxnArgs and calls the target
  ## function, looked up by name in fxnEnv.
  func <- function(x) {
      do.call("{", fxnArgs_assign_code)
      ## for(i in 1:length_x) {
      ##     ## Each line is like fxnArgs[[ 2 ]][23] <- x[5]
      ##     eval(fxnArgs_assign_code[[i]])
      ## }
      c(do.call(derivFxnName, fxnArgs, envir = fxnEnv)) #c() unrolls any answer to a vector
  }

  ans <- calcDerivs_internal(func, currentX, order, result_x_indices_all)
##  jacobian(func, currentX)
  ans 
}


## Numerical derivatives of a function call with respect to elements picked
## out by numeric index.
##
## Arguments:
##   derivFxn      the function being differentiated (used for its formals).
##   derivFxnCall  the call to differentiate.  Its arguments are evaluated in
##                 fxnEnv and are taken to be in the order of the formals.
##   order         requested orders, passed on to calcDerivs_internal().
##   wrt           numeric vector of indices into the concatenation of all
##                 arguments, each flattened in column-major order.  Repeats
##                 and arbitrary ordering are allowed.
##   fxnEnv        environment in which the call's arguments are evaluated
##                 and in which the function is looked up by name.
##
## Returns the result of calcDerivs_internal(), with one Jacobian column (and
## Hessian row/column) per entry of wrt, in the order given.
##
## Raises an error if wrt contains NA or an index outside the range of the
## concatenated arguments.
##
## The local names `fxnArgs`, `currentX` and the func argument `x` are
## referred to by name in code generated below and must keep these names.
nimDerivs_nf_numericwrt <- function(derivFxn,
                                    derivFxnCall,
                                    order,
                                    wrt,
                                    fxnEnv) {
  fA <- formalArgs(derivFxn)
  derivFxnCall_args <- as.list(derivFxnCall)[-1]

  fxnArgs <- lapply(derivFxnCall_args, function(x) eval(x, envir = fxnEnv))
  names(fxnArgs) <- as.character(fA)

  ## Number of scalar elements in each argument.
  argSizes <- unlist(lapply(fxnArgs, function(x) prod(dimOrLength(x))))

  ## Position in the concatenated vector at which each argument starts.
  flatStartIndices <- cumsum(c(1, argSizes[-length(argSizes)]))
  names(flatStartIndices) <- names(fxnArgs)
  ## For every position in the concatenated vector, the owning argument name.
  wrtIndexToArgName <- character()
  for(i in seq_along(fxnArgs)) {
    wrtIndexToArgName <- c(wrtIndexToArgName,
                           rep(names(fxnArgs)[i], argSizes[[i]]))
  }

  ## Reject indices that do not address any element; they would otherwise
  ## generate assignments with NA or empty subscripts.
  totalSize <- length(wrtIndexToArgName)
  if(anyNA(wrt) || any(wrt < 1 | wrt > totalSize)) {
    stop("Invalid wrt argument in nimDerivs_nf: numeric wrt values must be between 1 and ",
         totalSize, ".")
  }

  ## We need lines like
  ## fxnArgs[[ argName ]][j] <- x[k]
  ## j is the flat index of the variable
  ## k is the index of the concatenated x
  wrtUnique <- unique(wrt)
  length_x <- length(wrtUnique)
  fxnArgs_assign_code <- vector("list", length_x)
  get_init_values_code <- vector("list", length_x)
  result_x_indices <- integer(length(wrt))
  for(iWrt in seq_along(wrtUnique)) {
    wrtIndex <- wrtUnique[iWrt]
    argName <- wrtIndexToArgName[wrtIndex]
    j <- as.integer(wrtIndex - flatStartIndices[argName] + 1)
    ## Both lists are indexed by iWrt so that entry k of each refers to x[k].
    ## A plain `<-` suffices here: the line is evaluated inside func, where it
    ## makes a modified local copy of fxnArgs that func then passes on.
    fxnArgs_assign_code[[iWrt]] <-
      substitute(fxnArgs[[ ARGNAME ]][J] <- x[K],
                 list(ARGNAME = argName,
                      J = j,
                      K = iWrt))
    get_init_values_code[[iWrt]] <-
      substitute(currentX[K] <- fxnArgs[[ ARGNAME ]][J],
                 list(ARGNAME = argName,
                      J = j,
                      K = iWrt))
    ## Every original wrt entry equal to this index maps to x[iWrt]; this
    ## handles redundant or disordered requests.
    result_x_indices[wrt == wrtIndex] <- iWrt
  }
  currentX <- numeric(length_x)
  ## Runs the generated assignments in this frame, filling currentX.
  do.call("{", get_init_values_code)

  derivFxnName <- as.character(derivFxnCall[[1]])
  ## func copies x into the wrt elements of fxnArgs and calls the target
  ## function, looked up by name in fxnEnv.
  func <- function(x) {
    do.call("{", fxnArgs_assign_code)
    c(do.call(derivFxnName, fxnArgs, envir = fxnEnv)) #c() unrolls any answer to a vector
  }
  ans <- calcDerivs_internal(func, currentX, order, result_x_indices)
  ans 
}

nimDerivs_nf <- function(call = NA,
                         wrt = NA,
                         order = nimC(0,1,2),
                         derivFxnCall = NULL,
                         fxnEnv = parent.frame(),
                         model = NA){
  fxnCall <- match.call()
  ## This may be called directly (for testing) or from nimDerivs (typically).
  ## In the former case, we get derivFxnCall from the nimFxn argument.
  ## In the latter case, it is already match.call()ed and is passed here via derivFxnCall
  if(is.null(derivFxnCall))
    derivFxnCall <- fxnCall[['call']] ## either calculate(model, nodes) or model$calculate(nodes)
  else
    if(is.null(fxnCall[['order']]))
      fxnCall[['order']] <- order ## I don't think this is meaningful any more. fxnCall is not used further

  ## get the actual function
  derivFxn <- eval(derivFxnCall[[1]], envir = fxnEnv)

  ## Initialize information for restoring model when needed
  e <- environment(derivFxn)
  ## There might be no model in the inner nimDerivs call, so need to check for e$restoreInfo.
  if(inherits(model, 'modelBaseClass') || exists('restoreInfo', e)) {
      if(!exists('restoreInfo', e)) {
          ## First nimDerivs call: save model state and initialize derivative status.
          e$restoreInfo <- new.env()
          e$restoreInfo$currentDepth <- 1
          e$restoreInfo$deepestDepth <- 1
          vars <- model$getVarNames(includeLogProb = TRUE)
          e$restoreInfo$values <- list()
          length(e$restoreInfo$values) <- length(vars)
          for(v in seq_along(vars)) 
              e$restoreInfo$values[[v]] <- model[[vars[v]]]
          names(e$restoreInfo$values) <- vars
      } else {
          ## We are in nested calls to nimDerivs.
          e$restoreInfo$currentDepth <- e$restoreInfo$currentDepth + 1
          if(e$restoreInfo$deepestDepth < e$restoreInfo$currentDepth)
              e$restoreInfo$deepestDepth <- e$restoreInfo$currentDepth
      }
  } else { # partial check for whether there is a model in the nimbleFunction
      if(is(derivFxn, 'refMethodDef') && is.nf(e$.self)) {
          isModel <- sapply(names(e), function(x) is.model(e[[x]]))
          if(any(isModel)) {
              modelElement <- names(e)[which(isModel)]
              warning("nimDerivs_nf: detected a model, ", paste(modelElement, collapse = ','), ", associated with the nimbleFunction whose method is being differentiated. If model calculations are done in the method being differentiated, the 'model' argument to 'nimDerivs' should be included to ensure correct restoration of values in the model.")
          }
      }
  }
  
  ## standardize the derivFxnCall arguments
  derivFxnCall <- match.call(derivFxn, derivFxnCall)

  fA <- formalArgs(derivFxn)
  if(is.null(wrt) || (length(wrt)==1 && is.na(wrt[1]))) {
    wrt <- fA
  }
  if(is.character(wrt)) {
    ans <- nimDerivs_nf_charwrt(derivFxn,
                                derivFxnCall,
                                order,
                                wrt,
                                fxnEnv)
  } else if(is.numeric(wrt)) {
    ans <- nimDerivs_nf_numericwrt(derivFxn,
                                   derivFxnCall,
                                   order,
                                   wrt,
                                   fxnEnv)
  } else {
    stop("Invalid wrt argument in nimDerivs_nf")
  }

  ## Restore model state if appropriate and update restoration info
  if(exists('restoreInfo', e)) {
      ## Restore model state if double-taping or not asking for order 0 in single tape.
      ## Per NCT issue 256, compiled double-taping with inner tape = 0 does not update model,
      ## so mimicking that here.
      ## It's possible no model was provided to the nested nimDerivs call.
    if((inherits(model, 'modelBaseClass')) &&
       (e$restoreInfo$currentDepth > 1 || e$restoreInfo$deepestDepth > 1 || !0 %in% order)) {
          for(v in seq_along(e$restoreInfo$values)) {
              nm <- names(e$restoreInfo$values)[v]
              model[[nm]] <- e$restoreInfo$values[[nm]]
          }
      }
      
      if(e$restoreInfo$currentDepth == 1) {
          rm('restoreInfo', pos = e)
      } else {
          e$restoreInfo$currentDepth <- e$restoreInfo$currentDepth - 1
      }
  }
  ans
}

## Convert the `wrt` specification of a derivative call into positions within
## the flattened (concatenated) vector of all arguments of the method being
## differentiated.
##
## Arguments:
##   wrtArgs     The wrt specification.  Handled forms:
##                 - a character vector such as c('x', 'y[2]', 'z[1:2, ]');
##                 - a call whose function is `nimC`, wrapping such elements;
##                 - NA (or anything for which all(is.na(wrtArgs)) is TRUE,
##                   which includes a zero-length object): -1 is returned;
##                 - an object whose first element is NULL: all arguments of
##                   the method are used;
##                 - any other non-character object is deparsed to text.
##   nimFxnArgs  Named list of the method's argument type declarations as
##               unevaluated code, e.g. quote(double(1, 4)).
##   fxnName     Name of the method; used only in error messages.
##
## Value:
##   -1 for the NA case above.  Otherwise an unnamed numeric vector of 1-based
##   positions, in the order the wrt elements were given.  Matrix arguments
##   are flattened in column-major order.
##
## Errors are raised if a wrt name is not an argument of the method, if a wrt
## argument has no type declaration, is not declared as double(), or has more
## than 2 dimensions, if any argument with dimension > 0 lacks explicit
## sizes, or if an index in wrt is malformed or out of range.
convertWrtArgToIndices <- function(wrtArgs, nimFxnArgs, fxnName){
  ## If a vector of wrt args was given as nimC(...), get the individual args.
  ## identical() keeps the condition length-1 even if deparse() returns
  ## several lines.
  if(identical(deparse(wrtArgs[[1]]), 'nimC')){
    wrtArgs <- sapply(wrtArgs[-1], function(x){as.character(x)})
  }
  if(all(is.na(wrtArgs))){
    return(-1)
  }
  else if(is.null(wrtArgs[[1]])){
    wrtArgs <- names(nimFxnArgs)
  }
  else if(!is.character(wrtArgs)){
    wrtArgs <- deparse(wrtArgs)
  }

  ## Split every wrt element into the argument name and its (optional) list
  ## of index expressions.
  ## NOTE(review): makeParsedVarList is defined outside this block; its
  ## result is assumed to hold the parsed expression at [[2]][[2]] -- confirm.
  wrtArgParsed <- lapply(wrtArgs, function(x) {
    pvl <- .Call(makeParsedVarList, x)[[2]][[2]]
    if(is.name(pvl))
      list(argName = deparse(pvl),
           argIndices = NULL)
    else
      list(argName = deparse(pvl[[2]]),
           argIndices = as.list(pvl)[-c(1:2)])
  })
  wrtArgNames <- unlist(lapply(wrtArgParsed, `[[`, 'argName'))
  nimFxnArgNames <- names(nimFxnArgs)
  wrtMatchArgs <- which(nimFxnArgNames %in% wrtArgNames)
  argNameCheck <- wrtArgNames %in% nimFxnArgNames

  ## Compare wrt args to actual function args and make sure no erroneous args
  ## are present.
  if(!all(argNameCheck))
    stop('Incorrect names passed to wrt argument of nimDerivs: ', fxnName,
         ' does not have arguments named: ',
         paste(wrtArgNames[!argNameCheck], collapse = ', '), '.')

  ## Make sure all wrt args have type declarations (i.e., are not just names
  ## without types).
  lacksType <- vapply(wrtMatchArgs,
                      function(x) is.name(nimFxnArgs[[x]]),
                      logical(1))
  if(any(lacksType))
    stop('Derivatives of ', fxnName,
         ' being taken WRT an argument that does not have a valid type.')

  ## Make sure all wrt args have type double.
  doubleCheck <- vapply(wrtMatchArgs,
                        function(x)
                          identical(deparse(nimFxnArgs[[x]][[1]]), 'double'),
                        logical(1))
  if(!all(doubleCheck))
    stop('Derivatives of ', fxnName,
         ' being taken WRT an argument that does not have type double().')

  ## Make sure that all wrt arg dims are <= 2.  A declaration with no
  ## arguments, e.g. double(), counts as dimension 0.
  fxnArgsDims <- vapply(wrtMatchArgs, function(x) {
    if(length(nimFxnArgs[[x]]) == 1) 0 else nimFxnArgs[[x]][[2]]
  }, numeric(1))
  names(fxnArgsDims) <- nimFxnArgNames[wrtMatchArgs]
  if(any(fxnArgsDims > 2))
    stop('Derivatives cannot be taken WRT an argument with dimension > 2')

  ## Determine sizes of each function arg (all args, not only wrt args, since
  ## every arg contributes to the offsets in the flattened vector).
  fxnArgsDimSizes <- lapply(nimFxnArgs, function(x){
    if(length(x) == 1) return(1)
    if(x[[2]] == 0) return(1)
    if(length(x) < 3) stop(paste0('Problem building derivatives for nimbleFunction method ', fxnName,
                                 ': If the mode of buildDerivs is static=TRUE, sizes of arguments ',
                                 'must be explicitly specified (e.g. x = double(1, 4)).  If the mode of ',
                                 'buildDerivs is static=FALSE, then wrt cannot use argument ',
                                 'names.'))
    if(x[[2]] == 1){
      ## Size given either directly (double(1, 4)) or wrapped (double(1, c(4))).
      if(length(x[[3]]) == 1) return(x[[3]])
      else return(x[[3]][[2]])
    }
    else if(x[[2]] == 2) return(c(x[[3]][[2]], x[[3]][[3]]))
  })

  ## Same as above sizes, except that matrix sizes are reported as nrow*ncol
  ## instead of c(nrow, ncol).
  fxnArgsTotalSizes <- vapply(fxnArgsDimSizes, prod, numeric(1))

  ## fxnArgsIndexVector is a named vector with the starting index of each
  ## argument to the nimFxn, if all arguments were flattened and put into a
  ## single vector.
  ## E.g. if nimFxn has arguments x = double(2, c(2, 2)), y = double(1, 3),
  ## and z = double(0), then fxnArgsIndexVector will be:
  ##   x  y  z
  ##   1  5  8
  fxnArgsIndexVector <-
    cumsum(c(0, fxnArgsTotalSizes[-length(fxnArgsTotalSizes)])) + 1
  names(fxnArgsIndexVector) <- names(fxnArgsTotalSizes)

  ## Collect the positions contributed by each wrt element in a preallocated
  ## list and flatten once at the end, instead of growing a vector.
  indexPieces <- vector('list', length(wrtArgNames))
  for(i in seq_along(wrtArgNames)){
    argName <- wrtArgNames[i]
    argDim <- fxnArgsDims[[argName]]
    argStart <- fxnArgsIndexVector[[argName]]
    indexSpec <- wrtArgParsed[[i]]$argIndices
    piece <- NULL
    if(argDim == 0){
      piece <- argStart
    }
    else if(argDim == 1){
      argSize <- fxnArgsTotalSizes[[argName]]
      if(is.null(indexSpec)){
        ## No indexing: the whole vector.
        piece <- argStart + seq_len(argSize) - 1
      }
      else{
        if(length(indexSpec) != 1)
          stop('Incorrect indexing provided for wrt argument to nimDerivs(): ',
               wrtArgs[i])
        if(is.blank(indexSpec[[1]]))
          argIndices <- seq_len(argSize)
        else {
          ## NOTE(review): the index expression is evaluated in this
          ## function's frame, so it can see local variables -- confirm that
          ## only literal/constant index expressions are expected here.
          argIndices <- eval(indexSpec[[1]])
          if(any(argIndices < 1 | argIndices > argSize))
            stop('Index too large (or < 1) provided in wrt argument: ',
                 wrtArgs[i], ' for derivatives of ', fxnName)
        }
        piece <- argStart + argIndices - 1
      }
    }
    else if(argDim == 2){
      if(is.null(indexSpec)){
        ## No indexing: the whole matrix.
        piece <- argStart + seq_len(fxnArgsTotalSizes[[argName]]) - 1
      }
      else{
        if(length(indexSpec) != 2)
          stop('Incorrect indexing provided for wrt argument to nimDerivs(): ',
               wrtArgs[i])
        dimSizes <- fxnArgsDimSizes[[argName]]
        if(is.blank(indexSpec[[1]]))
          argIndicesRows <- seq_len(dimSizes[1])
        else {
          argIndicesRows <- eval(indexSpec[[1]])
          if(any(argIndicesRows < 1 | argIndicesRows > dimSizes[1]))
            stop('A row index that is too large (or < 1) ',
                 'was provided in wrt argument: ',
                 wrtArgs[i], ' for derivatives of ', fxnName)
        }
        if(is.blank(indexSpec[[2]]))
          argIndicesCols <- seq_len(dimSizes[2])
        else {
          argIndicesCols <- eval(indexSpec[[2]])
          if(any(argIndicesCols < 1 | argIndicesCols > dimSizes[2]))
            stop('A column index that is too large (or < 1) ',
                 'was provided in wrt argument: ',
                 wrtArgs[i], ' for derivatives of ', fxnName)
        }
        ## Column major ordering: all requested rows of the first requested
        ## column, then of the second, and so on.
        piece <- unlist(lapply(argIndicesCols, function(col) {
          argStart + (col - 1) * dimSizes[1] + argIndicesRows - 1
        }))
      }
    }
    ## Assign as a one-element list so that a NULL piece does not delete the
    ## slot.
    indexPieces[i] <- list(piece)
  }
  unname(unlist(indexPieces))
}

Try the nimble package in your browser

Any scripts or data that you put into this service are public.

nimble documentation built on July 9, 2023, 5:24 p.m.