Taylor Mode Autodiff in PyTorch
This library provides a PyTorch implementation of Taylor mode automatic differentiation, a generalization of forward mode to higher-order derivatives.
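The following minimal pure-Python sketch shows what "a generalization of forward mode" means in practice. It is not the jet library's implementation. It pushes a truncated Taylor polynomial through f(x) = x**4 using only multiplication. Internally it stores normalized coefficients, so a product of two polynomials is a Cauchy product. It converts to and from the derivative convention that the quickstart appears to use, x(t) = Σ x_k t^k / k!. The helper names (mul, to_normalized, to_derivs, pow4_jet) are hypothetical.

```python
from math import factorial


def mul(a, b):
    """Cauchy product of two truncated Taylor polynomials of equal order K.

    Each polynomial is a list of normalized coefficients [c0, ..., cK],
    representing c0 + c1*t + ... + cK*t**K; terms above t**K are dropped.
    """
    K = len(a) - 1
    return [sum(a[i] * b[k - i] for i in range(k + 1)) for k in range(K + 1)]


def to_normalized(derivs):
    """Derivative convention (x0, ..., xK) -> normalized coefficients x_k / k!."""
    return [d / factorial(k) for k, d in enumerate(derivs)]


def to_derivs(coeffs):
    """Normalized coefficients -> derivative convention c_k * k!."""
    return [c * factorial(k) for k, c in enumerate(coeffs)]


def pow4_jet(derivs):
    """Propagate a jet through f(x) = x**4, computed as (x*x)*(x*x)."""
    x = to_normalized(derivs)
    x2 = mul(x, x)
    return to_derivs(mul(x2, x2))


# Seed x0 = 2, x1 = 1, x2 = x3 = 0, as in the quickstart.
f0, f1, f2, f3 = pow4_jet([2.0, 1.0, 0.0, 0.0])
print(f0, f1, f2, f3)  # 16.0 32.0 48.0 48.0
```

Plain forward mode corresponds to keeping only the first two coefficients. Higher orders come for free from the same propagation rule. When x₂ ≠ 0, the output also picks up chain-rule cross terms. For example, pow4_jet([2.0, 1.0, 1.0, 0.0])[3] evaluates to 192.0, which equals f'''(2) + 3 f''(2).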
It is similar to JAX's Taylor mode (jax.experimental.jet).
The repository also hosts the Python functionality, the experiments, and the LaTeX source for our NeurIPS 2025 paper "Collapsing Taylor Mode Automatic Differentiation", which shows how to further accelerate Taylor mode for many practical differential operators.
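Many such operators are sums of directional derivatives. For example, the Laplacian Δf = Σ_i ∂²f/∂x_i² can be assembled from second-order Taylor mode. The sketch propagates one jet along each coordinate direction e_i and sums the second-order outputs. It is written in pure Python and shows only this standard, un-collapsed baseline. It does not implement the paper's collapsing technique, and none of its helpers (mul, add, f_jet, laplacian) are part of jet's API. The test function f(x, y) = x²y + y³ is a hypothetical choice with Δf = 8y.

```python
def mul(a, b):
    """Cauchy product of truncated Taylor polynomials (normalized coefficients)."""
    K = len(a) - 1
    return [sum(a[i] * b[k - i] for i in range(k + 1)) for k in range(K + 1)]


def add(a, b):
    """Coefficient-wise sum of two truncated Taylor polynomials."""
    return [ai + bi for ai, bi in zip(a, b)]


def f_jet(x, y):
    """Hypothetical test function f(x, y) = x**2 * y + y**3 on Taylor polynomials."""
    return add(mul(mul(x, x), y), mul(mul(y, y), y))


def laplacian(point):
    """Sum of second directional derivatives along each coordinate axis."""
    total = 0.0
    for i in range(len(point)):
        # Second-order jet of the curve t -> point + t * e_i, stored as
        # normalized coefficients [value, direction component, 0].
        jets = [[p, 1.0 if j == i else 0.0, 0.0] for j, p in enumerate(point)]
        out = f_jet(*jets)
        total += 2.0 * out[2]  # d²/dt² equals 2! times the t**2 coefficient
    return total


print(laplacian([1.0, 2.0]))  # 16.0, matching 8 * y at y = 2
```

This baseline costs one full second-order propagation per input dimension. That cost is what makes operator-level speedups attractive.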
Operator coverage is still growing, so please help us improve the package by providing feedback, filing issues, and opening pull requests.
Getting Started
Installation
Quickstart
Compute the third-order derivative of a scalar function with jet:
from torch import tensor, ones_like, zeros_like
from jet import jet
f = lambda x: x**4 # f'''(x) = 24 * x
x = tensor(2.0)
f_jet = jet(f, (x,)) # (x₀, x₁, x₂, x₃) ↦ (f₀, f₁, f₂, f₃)
# x₀ = x, x₁ = 1, x₂ = x₃ = 0 ⇒ f₃ = f'''(x)
f0, f1, f2, f3 = f_jet((x, ones_like(x), zeros_like(x), zeros_like(x)))
assert f3.allclose(24 * x) # third derivative matches by hand
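The comment's seeding isolates f'''(x) for the following reason. Assume jet uses the derivative convention, which matches the asserted f₃ = 24x. In that convention the input curve is x(t) = x₀ + x₁t + x₂t²/2 + x₃t³/6, and each output is f_k = d^k/dt^k f(x(t)) evaluated at t = 0. The chain rule (Faà di Bruno's formula) then gives:

```latex
\begin{aligned}
f_0 &= f(x_0), \\
f_1 &= f'(x_0)\, x_1, \\
f_2 &= f''(x_0)\, x_1^2 + f'(x_0)\, x_2, \\
f_3 &= f'''(x_0)\, x_1^3 + 3 f''(x_0)\, x_1 x_2 + f'(x_0)\, x_3.
\end{aligned}
```

With x₁ = 1 and x₂ = x₃ = 0, every term except the first in each line vanishes. So f₁ = f'(x₀), f₂ = f''(x₀), and f₃ = f'''(x₀) = 24 · 2 = 48 for f(x) = x⁴ at x₀ = 2.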
For the full list of supported operators, see the supported operations section of the introductory tutorial.
Examples
See the documentation.
Citing
If you find the jet package useful for your research, consider citing