Taylor Mode Autodiff in PyTorch
This library provides a PyTorch implementation of Taylor mode automatic differentiation, a generalization of forward mode to higher-order derivatives.
It is similar to JAX's Taylor mode (jax.experimental.jet).
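To make the idea concrete, here is a minimal pure-Python sketch of Taylor mode for scalar functions. It is not the jet package's API; the names `Jet`, `exp`, and `derivatives` are hypothetical. The sketch assumes a curve x(t) = c0 + c1 t + ... + cK t^K is pushed through each operation as its truncated list of Taylor coefficients. With K = 1 this reduces to ordinary forward mode (dual numbers), and larger K yields higher-order derivatives along the curve in a single forward pass.

```python
import math


class Jet:
    """Hypothetical helper: truncated Taylor polynomial c[0] + c[1] t + ... + c[K] t^K."""

    def __init__(self, coeffs):
        self.c = [float(v) for v in coeffs]

    @property
    def order(self):
        return len(self.c) - 1

    def __add__(self, other):
        if not isinstance(other, Jet):
            # promote a constant to a jet with zero higher-order coefficients
            other = Jet([other] + [0.0] * self.order)
        return Jet([a + b for a, b in zip(self.c, other.c)])

    __radd__ = __add__

    def __mul__(self, other):
        if not isinstance(other, Jet):
            return Jet([other * a for a in self.c])
        # Cauchy product of the two polynomials, truncated at order K
        K = self.order
        return Jet([sum(self.c[j] * other.c[k - j] for j in range(k + 1))
                    for k in range(K + 1)])

    __rmul__ = __mul__


def exp(x):
    # y = exp(x) satisfies y' = x' y, giving k * y_k = sum_{j=1..k} j * x_j * y_{k-j}
    y = [math.exp(x.c[0])]
    for k in range(1, x.order + 1):
        y.append(sum(j * x.c[j] * y[k - j] for j in range(1, k + 1)) / k)
    return Jet(y)


def derivatives(jet):
    # the k-th derivative w.r.t. t at t = 0 equals k! times the k-th coefficient
    return [math.factorial(k) * ck for k, ck in enumerate(jet.c)]


x0 = 0.5
x = Jet([x0, 1.0, 0.0, 0.0])  # x(t) = x0 + t, propagated to order K = 3
y = x * exp(x)                # f(x) = x * exp(x)
print(derivatives(y))         # should match (x0 + k) * exp(x0) for k = 0..3
```

Storing normalized Taylor coefficients (derivatives divided by k!) keeps the multiplication rule a plain Cauchy product; `derivatives` converts back at the end.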
The repository also hosts the Python functionality and experiments, as well as the LaTeX source, for our NeurIPS 2025 paper "Collapsing Taylor Mode Automatic Differentiation", which shows how to further accelerate Taylor mode for many practical differential operators.
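As an illustration of a differential operator that Taylor mode can evaluate, the sketch below computes the Laplacian of f: R^n -> R by running one second-order Taylor pass per coordinate direction e_i and summing the second derivatives of t -> f(x + t e_i). This is only the standard, non-collapsed baseline written in pure Python; it is not the collapsing technique from the paper and not the jet package's API. The names `J2`, `sin`, and `laplacian` are hypothetical.

```python
import math


class J2:
    """Hypothetical helper: second-order truncated Taylor polynomial c0 + c1 t + c2 t^2."""

    def __init__(self, c0, c1=0.0, c2=0.0):
        self.c0, self.c1, self.c2 = float(c0), float(c1), float(c2)

    def __mul__(self, o):
        if not isinstance(o, J2):
            o = J2(o)
        # product of two polynomials, truncated after t^2
        return J2(self.c0 * o.c0,
                  self.c0 * o.c1 + self.c1 * o.c0,
                  self.c0 * o.c2 + self.c1 * o.c1 + self.c2 * o.c0)

    __rmul__ = __mul__

    def __add__(self, o):
        if not isinstance(o, J2):
            o = J2(o)
        return J2(self.c0 + o.c0, self.c1 + o.c1, self.c2 + o.c2)

    __radd__ = __add__


def sin(x):
    # sin(c0 + u) with u = c1 t + c2 t^2 equals sin(c0) + cos(c0) u - sin(c0) u^2 / 2 + O(t^3)
    s, c = math.sin(x.c0), math.cos(x.c0)
    return J2(s, c * x.c1, c * x.c2 - 0.5 * s * x.c1 ** 2)


def laplacian(f, x):
    """Sum over i of d^2/dt^2 f(x + t e_i) at t = 0, one Taylor pass per direction."""
    total = 0.0
    for i in range(len(x)):
        # seed the curve x + t e_i: first-order coefficient e_i, second-order coefficient 0
        xs = [J2(xj, 1.0 if j == i else 0.0) for j, xj in enumerate(x)]
        total += 2.0 * f(xs).c2  # second derivative = 2! * c2
    return total


# f(x, y) = sin(x) * y^2 has Laplacian -sin(x) * y^2 + 2 * sin(x)
print(laplacian(lambda v: sin(v[0]) * v[1] * v[1], [0.3, 1.5]))
```

Each pass recomputes the same zeroth-order values, and the cost grows with the number of directions; this redundancy is the kind of overhead that accelerated schemes for such operators aim to reduce.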
🔪 Warning: expect rough edges! 🔪
This is a research prototype with various limitations (e.g., limited operator coverage). We strongly recommend double-checking your results with PyTorch's autodiff. Please help us improve the package by providing feedback, filing issues, and opening pull requests.
Getting Started
Installation
Examples
See the documentation.
Citing
If you find the jet package useful for your research, please consider citing our paper.