autodiff | R Documentation |
This function enables or disables automatic differentiation using the JAX package in Python. When a model includes many parameters, this can considerably speed up the computation of standard errors and increase their accuracy.
autodiff(autodiff = NULL, install = FALSE)
autodiff |
Logical flag. If |
install |
Logical flag. If |
When autodiff = TRUE, this function:
Imports the marginaleffectsAD Python package via reticulate::py_install()
Sets the internal jacobian function to use JAX-based automatic differentiation
Provides faster and more accurate gradient computation for supported models
Falls back on the default finite difference method for unsupported models and calls.
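To illustrate why the derivative method matters, here is a minimal base R sketch comparing a forward finite-difference Jacobian with the exact derivative of a logistic prediction. It does not use JAX or any marginaleffects internals; the covariate values, coefficients, and helper names (predict_fn, fd_jacobian, exact_jacobian) are hypothetical and chosen only for illustration. Automatic differentiation returns the exact derivative up to floating-point precision, whereas a finite difference is an approximation whose error depends on the step size.

```r
# Hypothetical covariate row and coefficients (not taken from a fitted model).
x <- c(1, 2, -1)
beta <- c(0.5, 0.5, 0.5)

# Prediction on the response scale for a logistic model.
predict_fn <- function(b) plogis(sum(x * b))

# Forward finite difference: perturb one coefficient at a time.
fd_jacobian <- function(f, b, eps = 1e-7) {
  f0 <- f(b)
  vapply(seq_along(b), function(j) {
    b_step <- b
    b_step[j] <- b_step[j] + eps
    (f(b_step) - f0) / eps
  }, numeric(1))
}

# Exact derivative of plogis(x'b) with respect to b: p * (1 - p) * x.
exact_jacobian <- function(b) {
  p <- predict_fn(b)
  p * (1 - p) * x
}

J_fd <- fd_jacobian(predict_fn, beta)
J_exact <- exact_jacobian(beta)

# Approximation error of the finite-difference Jacobian.
max(abs(J_fd - J_exact))
```

The finite-difference version also needs one extra function evaluation per coefficient, which is where the cost grows for models with many parameters.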
Currently supports:
Model types: lm, glm, ols, lrm
Functions: predictions() and comparisons(), along with avg_ and plot_ variants.
type: "response" or "link"
by: TRUE, FALSE, or character vector.
comparison: "difference" and "ratio"
For unsupported models or options, the function automatically falls back to finite difference methods with a warning.
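As a hedged sketch of how the supported options listed above might be exercised, the calls below reuse the glm model from the examples on this page. It assumes marginaleffects and the Python dependencies are already installed. The last call uses comparison = "lnratio", which is not in the supported list, so by the description above it would be expected to fall back to finite differences with a warning; that behavior is inferred from this page, not verified.

```r
library(marginaleffects)

autodiff(TRUE)

mod <- glm(am ~ hp + wt, data = mtcars, family = binomial)

# Supported per the list above: type = "link"
predictions(mod, type = "link")

# Supported per the list above: comparison = "ratio"
avg_comparisons(mod, comparison = "ratio")

# Not in the supported list: expected to fall back to finite differences
avg_comparisons(mod, comparison = "lnratio")

autodiff(FALSE)
```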
No return value. Called for side effects of enabling/disabling automatic differentiation.
By default, no manual configuration of Python should be necessary. On most
machines, unless you have explicitly configured reticulate, reticulate
defaults to an automatically managed ephemeral virtual environment with all
Python requirements declared via reticulate::py_require().
If you prefer to use a manually managed Python installation, you can direct
reticulate and specify which Python executable or environment to use.
reticulate selects a Python installation using its Order of Discovery.
As a convenience, autodiff(install = TRUE) will install marginaleffectsAD in
a self-managed virtual environment.
To specify an alternate Python version:
library(reticulate)
use_python("/usr/local/bin/python")
To use a virtual environment:
use_virtualenv("myenv")
These configuration commands should be called before calling autodiff().
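The ordering described above can be sketched as follows. This is a minimal example, assuming a virtual environment named "myenv" (the placeholder used above) already exists and contains the required Python packages. The py_config() and py_module_available() checks are optional additions for confirming which Python reticulate resolved; they are not part of autodiff() itself.

```r
library(reticulate)

# Select the environment first ("myenv" is a placeholder name).
use_virtualenv("myenv", required = TRUE)

# Optional check: report the Python installation reticulate resolved.
py_config()

# Optional check: JAX is imported under the module name "jax".
py_module_available("jax")

# Only then enable automatic differentiation.
library(marginaleffects)
autodiff(TRUE)
```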
## Not run:
# Install the Python package (only needed once)
autodiff(install = TRUE)
# Enable automatic differentiation
autodiff(TRUE)
# Fit a model and compute marginal effects
mod <- glm(am ~ hp + wt, data = mtcars, family = binomial)
avg_comparisons(mod) # Will use JAX for faster computation
# Disable automatic differentiation
autodiff(FALSE)
## End(Not run)