API Documentation
Function transformations
jet.jet
jet(
f: Callable[[*(tuple[PyTree[Tensor], ...])], PyTree[Tensor]],
mock_args: tuple[PyTree[Tensor], ...],
collapsed: bool = False,
) -> Callable[[*(tuple[PyTree[Jet], ...])], PyTree[Jet]]
Overload a function with its Taylor-mode equivalent.
The returned function is K-polymorphic: the derivative order is inferred
per call from the number of coefficients in the input jet tuples. The
collapsed flag selects between two propagation regimes (the
coefficient shape contract for each mode is given below); to freeze the
order into an FX GraphModule (e.g. for graph passes like CSE), apply
capture_graph to the returned callable yourself.
- standard mode (collapsed=False): every coefficient c_k has the primal's shape S.
- collapsed mode (collapsed=True): coefficients c_1, ..., c_{K-1} have shape (R, *S), carrying R directions; c_K has shape S (already summed over the directions). The K-th output coefficient is likewise returned collapsed. This exploits the fact that the highest-order coefficient enters linearly, so it can be summed eagerly and smaller tensors are propagated through the graph. Requires K >= 2; R may vary per call.
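The linearity argument behind collapsed mode can be checked on a scalar example. The sketch below assumes the derivative (Faà di Bruno) convention for coefficients, i.e. for a 2-jet (x0, c_1, c_2) the outputs are f_1 = f'(x0) c_1 and f_2 = f''(x0) c_1² + f'(x0) c_2. The helpers sin_2jet and sin_2jet_collapsed are hypothetical plain-Python stand-ins that only mimic the shape contract (R first-order directions, one pre-summed top coefficient); they are not part of jet.

```python
import math

def sin_2jet(x0, c1, c2):
    """Standard 2-jet of sin for a scalar primal (hypothetical helper)."""
    f0 = math.sin(x0)
    f1 = math.cos(x0) * c1
    f2 = -math.sin(x0) * c1 * c1 + math.cos(x0) * c2  # Faa di Bruno, K = 2
    return f0, f1, f2

def sin_2jet_collapsed(x0, c1s, c2_sum):
    """Collapsed 2-jet: c1s holds R directions, c2_sum is already summed over them."""
    f0 = math.sin(x0)
    f1s = [math.cos(x0) * c1 for c1 in c1s]  # keeps the leading R axis
    # f2 is linear in c2, so the direction sum can be applied before propagating;
    # the c1 * c1 term is not linear and still needs every direction.
    f2_sum = -math.sin(x0) * sum(c1 * c1 for c1 in c1s) + math.cos(x0) * c2_sum
    return f0, f1s, f2_sum

x0 = 0.3
c1s = [1.0, -0.5, 2.0]  # R = 3 first-order directions
c2s = [0.1, 0.2, -0.4]  # matching second-order coefficients

# Standard regime: propagate R separate jets, then sum the top coefficient.
f2_standard = sum(sin_2jet(x0, c1, c2)[2] for c1, c2 in zip(c1s, c2s))
# Collapsed regime: sum the top input coefficient first, then propagate once.
_, f1s, f2_collapsed = sin_2jet_collapsed(x0, c1s, sum(c2s))
print(f2_standard, f2_collapsed)
```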
Parameters:
- f (Callable[[*(tuple[PyTree[Tensor], ...])], PyTree[Tensor]]) – Function to overload. May accept and return pytrees of tensors.
- mock_args (tuple[PyTree[Tensor], ...]) – Mock input tensors (or pytrees of tensors) for tracing f's compute graph, provided as a tuple matching the positional arguments of f. Only shapes and dtypes matter, not the values.
- collapsed (bool, default: False) – Select between the two propagation regimes above. Default: False (standard mode).
Returns:
- Callable[[*(tuple[PyTree[Jet], ...])], PyTree[Jet]] – A callable jet_f(*args) taking one positional argument per argument of f. Each argument is a pytree mirroring the corresponding mock_args entry, but with every tensor leaf replaced by a tuple (primal, c_1, ..., c_K) bundling the primal with its K Taylor coefficients. jet_f returns a pytree mirroring f's output structure with each tensor leaf replaced by (f_0, f_1, ..., f_K).
Examples:
Single-input::
>>> from torch import sin, zeros, Tensor
>>> from jet import jet
>>> jet_f = jet(sin, (zeros(1),))
>>> x0, x1, x2 = Tensor([0.123]), Tensor([-0.456]), Tensor([0.789])
>>> f0, f1, f2 = jet_f((x0, x1, x2))
Multi-input::
>>> from torch import cos
>>> f = lambda x, y: sin(x) * cos(y)
>>> jet_f = jet(f, (zeros(3), zeros(3)))
>>> x, y = Tensor([0.1, 0.2, 0.3]), Tensor([0.4, 0.5, 0.6])
>>> vx, vy = Tensor([1.0, 0.0, 0.0]), Tensor([0.0, 1.0, 0.0])
>>> f0, f1 = jet_f((x, vx), (y, vy))
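Both examples can be sanity-checked without jet. Assuming the derivative convention (c_k is the k-th derivative of the input curve at t = 0), the single-input outputs should match central finite differences of t ↦ sin(x(t)). Because sin(x) * cos(y) acts element-wise, each component of the multi-input f1 is a directional derivative along (vx, vy). The sketch uses plain floats in place of tensors.

```python
import math

h = 1e-4  # finite-difference step

# --- Single-input example: 2-jet of sin ---
x0, x1, x2 = 0.123, -0.456, 0.789

def curve(t):
    # Assumed convention: x(0) = x0, x'(0) = x1, x''(0) = x2
    return x0 + t * x1 + 0.5 * t * t * x2

def g(t):
    return math.sin(curve(t))

# Closed-form coefficients from Faa di Bruno's formula
f0 = math.sin(x0)
f1 = math.cos(x0) * x1
f2 = -math.sin(x0) * x1 ** 2 + math.cos(x0) * x2

# Central differences of t -> sin(x(t)) at t = 0
fd1 = (g(h) - g(-h)) / (2 * h)
fd2 = (g(h) - 2 * g(0.0) + g(-h)) / h ** 2
print(abs(fd1 - f1), abs(fd2 - f2))

# --- Multi-input example: first-order coefficient of sin(x) * cos(y) ---
def f(x, y):
    return math.sin(x) * math.cos(y)

multi_errors = []
for x, y, vx, vy in zip([0.1, 0.2, 0.3], [0.4, 0.5, 0.6],
                        [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]):
    # Directional derivative along (vx, vy) via the product rule
    exact = math.cos(x) * math.cos(y) * vx - math.sin(x) * math.sin(y) * vy
    fd = (f(x + h * vx, y + h * vy) - f(x - h * vx, y - h * vy)) / (2 * h)
    multi_errors.append(abs(fd - exact))
print(max(multi_errors))
```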
Source code in jet/__init__.py
jet.laplacian.laplacian
laplacian(
f: Callable[[Tensor], Tensor],
mock_args: tuple[PyTree[Tensor], ...],
randomization: tuple[str, int] | None = None,
weighting: tuple[Callable[[Tensor, Tensor], Tensor], int] | None = None,
collapsed: bool = True,
) -> Callable[[*(tuple[PyTree[Tensor], ...])], Tensor]
Transform f into a function that computes lap(f(x)).
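The Laplacian of a function \(f(\mathbf{x}) \in \mathbb{R}\) with \(\mathbf{x} \in \mathbb{R}^D\) is defined as the Hessian trace, or
\[
\Delta f(\mathbf{x}) = \operatorname{Tr}\left(\nabla^2 f(\mathbf{x})\right) = \sum_{d=1}^D \frac{\partial^2 f(\mathbf{x})}{\partial x_d^2}.
\]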
For functions that produce vectors or tensors, the Laplacian is defined per output component and has the same shape as \(f(\mathbf{x})\).
Only single-tensor functions (one tensor in, one tensor out) are supported.
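The Hessian trace splits into D second-order directional derivatives along the standard basis vectors e_d. Assuming the derivative convention for Taylor coefficients, each term e_dᵀ∇²f(x) e_d is the second coefficient of the 2-jet of t ↦ f(x + t e_d), i.e. a jet with c_1 = e_d and c_2 = 0. The sketch below replaces Taylor mode by central finite differences and uses a hypothetical test function whose Laplacian is known in closed form.

```python
import math

def f(x):
    # Hypothetical scalar test function on R^3
    return x[0] ** 2 * x[1] + math.sin(x[2])

def exact_laplacian(x):
    # d^2/dx0^2 = 2 x1, d^2/dx1^2 = 0, d^2/dx2^2 = -sin(x2)
    return 2 * x[1] - math.sin(x[2])

def second_directional_derivative(f, x, v, h=1e-4):
    # Approximates v^T H(x) v, the second derivative of t -> f(x + t v) at t = 0
    plus = [xi + h * vi for xi, vi in zip(x, v)]
    minus = [xi - h * vi for xi, vi in zip(x, v)]
    return (f(plus) - 2 * f(x) + f(minus)) / h ** 2

def laplacian_fd(f, x):
    # Sum the second directional derivatives over the standard basis
    D = len(x)
    basis = [[1.0 if i == d else 0.0 for i in range(D)] for d in range(D)]
    return sum(second_directional_derivative(f, x, e) for e in basis)

x = [0.3, -1.2, 0.7]
print(laplacian_fd(f, x), exact_laplacian(x))
```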
Parameters:
- f (Callable[[Tensor], Tensor]) – The function whose Laplacian is computed. Must consume and return a single tensor.
- mock_args (tuple[PyTree[Tensor], ...]) – Mock positional arguments for tracing f, provided as a tuple matching f's positional arguments. Does not need to be the actual input; only shapes and dtypes matter. Currently must be a one-tuple of a single tensor.
- randomization (tuple[str, int] | None, default: None) – Optional tuple containing the distribution type and the number of samples for a randomized Laplacian. If provided, the Laplacian is computed using Monte-Carlo sampling. The first element is the distribution type (e.g., 'normal', 'rademacher'), and the second is the number of samples to use.
- weighting (tuple[Callable[[Tensor, Tensor], Tensor], int] | None, default: None) – A tuple specifying how the second-order derivatives are weighted. The weights are described by a coefficient tensor C(x) of shape [*D, *D]. The first entry is a function (x, V) -> V @ S(x).T that applies the symmetric factorization S(x) of the weights, C(x) = S(x) @ S(x).T, at the input x to the matrix V. S(x) has shape [*D, rank_C], while V has shape [K, rank_C] with arbitrary K. The second entry specifies rank_C. If None, the weights correspond to the identity matrix (i.e., the standard Laplacian is computed).
- collapsed (bool, default: True) – Whether to use collapsed Taylor mode. If True (default), uses the collapsed dispatch path (JetInterpreter(..., collapsed=True)), which directly propagates the summed second-order coefficient. If False, propagates full 2-jets over all directions via vmap and sums afterward.
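Both optional parameters rest on standard trace identities, sketched below in plain Python with a fixed symmetric matrix H standing in for the Hessian. For randomization, any direction distribution with E[v vᵀ] = I (Rademacher or standard normal) gives E[vᵀ H v] = Tr(H); for Rademacher, averaging over all 2^D sign vectors is exact. For weighting, assuming the weighted operator is Σ_ij C_ij ∂²f/∂x_i∂x_j, the factorization C = S Sᵀ turns it into rank_C quadratic forms along the columns of S (the rows of V @ S.T when V is the rank_C × rank_C identity). H, S, and the helper quad are illustrative only, not part of jet.

```python
import itertools
import random

D = 3
H = [[2.0, 0.5, -1.0],
     [0.5, 3.0, 0.2],
     [-1.0, 0.2, 1.5]]  # symmetric stand-in for the Hessian

def quad(H, v):
    """Quadratic form v^T H v."""
    return sum(v[i] * H[i][j] * v[j] for i in range(len(v)) for j in range(len(v)))

trace = sum(H[i][i] for i in range(D))  # 6.5

# randomization: Rademacher directions satisfy E[v v^T] = I
signs = list(itertools.product([-1.0, 1.0], repeat=D))
exact_avg = sum(quad(H, v) for v in signs) / len(signs)  # equals the trace
rng = random.Random(0)
samples = [[rng.choice([-1.0, 1.0]) for _ in range(D)] for _ in range(2000)]
mc_estimate = sum(quad(H, v) for v in samples) / len(samples)  # Monte-Carlo estimate

# weighting: C = S S^T with S of shape [D, rank_C]
rank_C = 2
S = [[1.0, 0.0],
     [0.5, 2.0],
     [0.0, -1.0]]
C = [[sum(S[i][r] * S[j][r] for r in range(rank_C)) for j in range(D)] for i in range(D)]
weighted_direct = sum(C[i][j] * H[i][j] for i in range(D) for j in range(D))
columns = [[S[i][r] for i in range(D)] for r in range(rank_C)]
weighted_via_S = sum(quad(H, s) for s in columns)
print(trace, exact_avg, round(mc_estimate, 2), weighted_direct, weighted_via_S)
```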
Returns:
- Callable[[*(tuple[PyTree[Tensor], ...])], Tensor] – A plain Python callable lap_f(*args) that maps x → lap(f(x)). To bake the operator into an FX GraphModule (for graph passes, torch.compile, etc.), apply capture_graph yourself.
Examples:
>>> from torch import manual_seed, rand, zeros
>>> from torch.func import hessian
>>> from torch.nn import Linear, Tanh, Sequential
>>> from jet.laplacian import laplacian
>>> _ = manual_seed(0) # make deterministic
>>> f = Sequential(Linear(3, 1), Tanh())
>>> x0 = rand(3)
>>> # Compute the Laplacian via Taylor mode
>>> lap = laplacian(f, (zeros(3),))(x0)
>>> assert lap.shape == f(x0).shape
>>> # Compute the Laplacian with PyTorch's autodiff (Hessian trace)
>>> lap_pt = hessian(f)(x0).squeeze(0).trace().unsqueeze(0)
>>> assert lap.shape == lap_pt.shape
>>> assert lap_pt.allclose(lap)
Source code in jet/laplacian.py
jet.bilaplacian.bilaplacian
bilaplacian(
f: Callable[[Tensor], Tensor],
mock_args: tuple[PyTree[Tensor], ...],
randomization: tuple[str, int] | None = None,
collapsed: bool = True,
) -> Callable[[*(tuple[PyTree[Tensor], ...])], Tensor]
Transform f into a function that computes the Bi-Laplacian.
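The Bi-Laplacian of a function \(f(\mathbf{x}) \in \mathbb{R}\) with \(\mathbf{x} \in \mathbb{R}^D\) is defined as the Laplacian of the Laplacian, or
\[
\Delta^2 f(\mathbf{x}) = \Delta\left(\Delta f(\mathbf{x})\right) = \sum_{i=1}^D \sum_{j=1}^D \frac{\partial^4 f(\mathbf{x})}{\partial x_i^2 \partial x_j^2}.
\]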
For functions that produce vectors or tensors, the Bi-Laplacian is defined per output component and has the same shape as \(f(\mathbf{x})\).
Only single-tensor functions (one tensor in, one tensor out) are supported.
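The definition can be verified exactly on a polynomial. The sketch below encodes polynomials as dicts from exponent tuples to integer coefficients (a hypothetical, self-contained encoding unrelated to jet) and checks, for f(x) = ‖x‖⁴ in D = 3, that the Laplacian of the Laplacian equals the double sum of fourth derivatives and the constant 8D(D + 2) = 120.

```python
from collections import defaultdict

# A polynomial is a dict mapping exponent tuples (e_0, ..., e_{D-1}) to integer coefficients.

def diff(p, d):
    """Exact partial derivative of polynomial p with respect to variable d."""
    out = defaultdict(int)
    for exps, c in p.items():
        if exps[d] > 0:
            new = list(exps)
            new[d] -= 1
            out[tuple(new)] += c * exps[d]
    return {e: c for e, c in out.items() if c != 0}

def add(p, q):
    """Sum of two polynomials."""
    out = defaultdict(int)
    for poly in (p, q):
        for e, c in poly.items():
            out[e] += c
    return {e: c for e, c in out.items() if c != 0}

def poly_laplacian(p, D):
    """Sum of unmixed second derivatives, i.e. the Hessian trace."""
    result = {}
    for d in range(D):
        result = add(result, diff(diff(p, d), d))
    return result

D = 3
# f(x) = (x_0^2 + x_1^2 + x_2^2)^2, expanded
f = {(4, 0, 0): 1, (0, 4, 0): 1, (0, 0, 4): 1,
     (2, 2, 0): 2, (2, 0, 2): 2, (0, 2, 2): 2}

lap = poly_laplacian(f, D)      # 20 * (x_0^2 + x_1^2 + x_2^2)
bilap = poly_laplacian(lap, D)  # Laplacian of the Laplacian

# Same quantity via the double sum of fourth derivatives
mixed = {}
for i in range(D):
    for j in range(D):
        mixed = add(mixed, diff(diff(diff(diff(f, i), i), j), j))

print(lap)
print(bilap, mixed)
```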
Parameters:
- f (Callable[[Tensor], Tensor]) – The function whose Bi-Laplacian is computed. Must consume and return a single tensor.
- mock_args (tuple[PyTree[Tensor], ...]) – Mock positional arguments for tracing f, provided as a tuple matching f's positional arguments. Only shapes and dtypes matter, not the values. Currently must be a one-tuple of a single tensor.
- randomization (tuple[str, int] | None, default: None) – Optional tuple containing the distribution type and the number of samples for a randomized Bi-Laplacian. If provided, the Bi-Laplacian is computed using Monte-Carlo sampling. The first element is the distribution type (must be 'normal'), and the second is the number of samples to use. Default is None.
- collapsed (bool, default: True) – Whether to use collapsed Taylor mode. If True (default), uses the collapsed dispatch path (JetInterpreter(..., collapsed=True)), which directly propagates the summed fourth-order coefficient. If False, propagates full 4-jets over all directions via vmap and sums afterward.
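The source states only that the randomized Bi-Laplacian requires 'normal'. One plausible reason, offered here as our own illustration rather than a description of jet's internals: for v ~ N(0, I), the fourth moments E[v_i v_j v_k v_l] = δ_ij δ_kl + δ_ik δ_jl + δ_il δ_jk give E[∂⁴f(x)[v, v, v, v]] = 3 Δ²f(x), whereas Rademacher directions have E[v_i⁴] = 1 instead of 3, which biases the same estimator. The sketch checks this on f(x) = ‖x‖⁴, whose fourth directional derivative is 24‖v‖⁴ at every x and whose Bi-Laplacian is 8D(D + 2).

```python
import random

D = 3

def fourth_directional_derivative(v):
    # For f(x) = |x|^4, f(x + t v) has t^4-coefficient |v|^4, so
    # d^4/dt^4 f(x + t v) at t = 0 equals 24 |v|^4 (independent of x).
    sq = sum(vi * vi for vi in v)
    return 24.0 * sq * sq

exact = 8 * D * (D + 2)  # Bi-Laplacian of |x|^4 in D dimensions (120 for D = 3)

rng = random.Random(0)
n = 100_000
# Gaussian directions: E[d^4 f(v, v, v, v)] = 3 * Bi-Laplacian
gauss_est = sum(
    fourth_directional_derivative([rng.gauss(0.0, 1.0) for _ in range(D)])
    for _ in range(n)
) / (3 * n)

# Rademacher directions: E[v_i^4] = 1 instead of 3, which biases the estimator
m = 1_000
rademacher_est = sum(
    fourth_directional_derivative([rng.choice([-1.0, 1.0]) for _ in range(D)])
    for _ in range(m)
) / (3 * m)

print(exact, round(gauss_est, 1), rademacher_est)
```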
Returns:
- Callable[[*(tuple[PyTree[Tensor], ...])], Tensor] – A plain Python callable bilap_f(*args) that maps x → bilap(f(x)). To bake the operator into an FX GraphModule (for graph passes, torch.compile, etc.), apply capture_graph yourself.
Examples:
>>> from torch import manual_seed, rand, zeros
>>> from torch.func import hessian
>>> from torch.nn import Linear, Tanh, Sequential
>>> from jet.bilaplacian import bilaplacian
>>> _ = manual_seed(0) # make deterministic
>>> f = Sequential(Linear(3, 1), Tanh())
>>> x0 = rand(3)
>>> # Compute the Bilaplacian via Taylor mode
>>> bilap = bilaplacian(f, (zeros(3),))(x0)
>>> assert bilap.shape == f(x0).shape
>>> # Compute the Bilaplacian with PyTorch's autodiff
>>> laplacian_pt = lambda x: hessian(f)(x).squeeze(0).trace().unsqueeze(0)
>>> bilaplacian_pt = hessian(laplacian_pt)(x0).squeeze(0).trace().unsqueeze(0)
>>> assert bilap.shape == bilaplacian_pt.shape
>>> assert bilaplacian_pt.allclose(bilap)
Source code in jet/bilaplacian.py
Graph capture
jet.capture_graph
capture_graph(
f: Callable[[*(tuple[PyTree[Tensor], ...])], PyTree[Tensor]],
mock_args: tuple[PyTree[Tensor], ...],
) -> tuple[GraphModule, TreeSpec]
Capture the compute graph of f as a GraphModule.
The returned GraphModule's forward takes the flat tensor leaves
of mock_args as positional arguments: make_fx cannot trace
through pytree containers, so mock_args is flattened and f is
traced over those flat leaves. To call the captured graph with f's
original pytree structure, use the second return value, in_spec, to
flatten new arguments in the order the graph expects::
mod, in_spec = capture_graph(f, mock_args)
out = mod(*in_spec.flatten_up_to(args))
Output structure is passed through unchanged: whatever f returns
(single tensor, tuple, dict, arbitrary pytree), make_fx's pytree
codegen reconstructs it on each call.
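The leaf-ordering contract can be illustrated without PyTorch. The sketch below uses hypothetical helpers flatten and flatten_like (not torch.utils._pytree; dict keys are taken in insertion order as an assumption, and flatten_like is a simplification that demands an exact structural match) with strings standing in for tensors. The traced graph would receive one placeholder per leaf, so new arguments must be flattened against the same spec.

```python
def flatten(tree):
    """Return (leaves, spec) for nested tuples, lists and dicts (hypothetical helper)."""
    if isinstance(tree, (tuple, list)):
        leaves, children = [], []
        for child in tree:
            child_leaves, child_spec = flatten(child)
            leaves.extend(child_leaves)
            children.append(child_spec)
        return leaves, (type(tree).__name__, None, tuple(children))
    if isinstance(tree, dict):
        leaves, children = [], []
        for key in tree:  # assumption: insertion order of keys
            child_leaves, child_spec = flatten(tree[key])
            leaves.extend(child_leaves)
            children.append(child_spec)
        return leaves, ("dict", tuple(tree), tuple(children))
    return [tree], "leaf"

def flatten_like(spec, tree):
    """Flatten `tree`, refusing structures that differ from `spec` (simplified)."""
    leaves, tree_spec = flatten(tree)
    if tree_spec != spec:
        raise ValueError("argument structure differs from the traced mock_args")
    return leaves

# mock_args for a hypothetical f(x, params) with params = {"w": ..., "b": ...}
mock_args = ("x_mock", {"w": "w_mock", "b": "b_mock"})
mock_leaves, in_spec = flatten(mock_args)  # one graph placeholder per leaf

new_args = ("x_new", {"w": "w_new", "b": "b_new"})
flat = flatten_like(in_spec, new_args)     # leaves in placeholder order
print(mock_leaves, flat)
```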
Parameters:
- f (Callable[[*(tuple[PyTree[Tensor], ...])], PyTree[Tensor]]) – Callable to trace (plain function, nn.Module, GraphModule, etc.). May accept pytrees of tensors as positional arguments.
- mock_args (tuple[PyTree[Tensor], ...]) – Mock inputs (pytrees of tensors) matching f's positional arguments, provided as a tuple. Only shapes and dtypes matter.
Returns:
- tuple[GraphModule, TreeSpec] – A pair (mod, in_spec), where mod is a GraphModule whose forward takes the flat tensor leaves of mock_args, and in_spec is the TreeSpec of mock_args. Use in_spec.flatten_up_to(args) to obtain the flat leaves in the order mod expects.
Raises:
- TypeError – If mock_args is not a tuple. Wrap a single positional argument as (x,).
Source code in jet/tracing.py
jet.simplify.common_subexpression_elimination
Replace duplicate subexpressions with a single node.
Parameters:
- graph (Graph) – The graph to be optimized.
- verbose (bool, default: False) – Whether to print debug information. Default: False.
Returns:
- bool – Whether a subexpression was replaced.
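The idea behind the pass can be sketched on a toy graph representation. The function cse below is a hypothetical illustration, neither jet's implementation nor a torch.fx.Graph pass: nodes are (name, op, args) tuples in topological order, every op except placeholders is assumed pure, and a node whose (op, args) key was already seen is dropped in favor of the first occurrence, with later uses rewritten. Like the documented pass, it reports whether anything was replaced.

```python
def cse(nodes):
    """Remove duplicate pure nodes from a topologically ordered node list.

    Each node is (name, op, args); args are names of earlier nodes or numeric
    constants. Returns (new_nodes, changed).
    """
    seen = {}     # (op, args) -> name of the first node computing it
    replace = {}  # name of a removed duplicate -> name of the surviving node
    new_nodes = []
    for name, op, args in nodes:
        # Rewrite uses of removed nodes before building the lookup key
        args = tuple(replace.get(a, a) if isinstance(a, str) else a for a in args)
        key = (op, args)
        if op != "placeholder" and key in seen:
            replace[name] = seen[key]
            continue
        seen[key] = name
        new_nodes.append((name, op, args))
    return new_nodes, bool(replace)

graph = [
    ("x", "placeholder", ()),
    ("a", "sin", ("x",)),
    ("b", "sin", ("x",)),      # duplicate of a
    ("c", "mul", ("a", "b")),
    ("d", "mul", ("a", "b")),  # duplicate of c once b is replaced
    ("out", "output", ("c", "d")),
]
new_graph, changed = cse(graph)
print([node[0] for node in new_graph], changed)
```

Removing b turns d into a duplicate of c, so one sweep in topological order catches both. A pass over torch.fx would additionally have to account for keyword arguments and side-effecting ops.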
Source code in jet/simplify.py
jet.utils.visualize_graph
Visualize the compute graph of a module.
Supported formats: .png, .pdf, .svg (inferred from savefile).
Parameters:
- mod (GraphModule) – The module whose compute graph to visualize.
- savefile (str) – The path to the file where the graph should be saved.
- name (str, default: '') – A name for the graph, used in the visualization.
- use_custom (bool, default: False) – If True, highlight sum nodes in orange-red and use white for other operations. Defaults to False.
Raises:
- ValueError – If savefile has an unsupported extension.
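The documented extension rule amounts to a small lookup. The sketch below uses hypothetical names (infer_format, SUPPORTED_FORMATS) and treats extensions case-sensitively, which the source does not specify.

```python
import os

# Hypothetical mapping from file extension to output format
SUPPORTED_FORMATS = {".png": "png", ".pdf": "pdf", ".svg": "svg"}

def infer_format(savefile: str) -> str:
    """Infer the output format from savefile's extension, as described above."""
    ext = os.path.splitext(savefile)[1]
    if ext not in SUPPORTED_FORMATS:
        raise ValueError(
            f"Unsupported extension {ext!r}; expected one of {sorted(SUPPORTED_FORMATS)}"
        )
    return SUPPORTED_FORMATS[ext]

print(infer_format("graphs/laplacian.svg"))
```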