
API Documentation

Function transformations

jet.jet

jet(
    f: Callable[[*(tuple[PyTree[Tensor], ...])], PyTree[Tensor]],
    mock_args: tuple[PyTree[Tensor], ...],
    collapsed: bool = False,
) -> Callable[[*(tuple[PyTree[Jet], ...])], PyTree[Jet]]

Overload a function with its Taylor-mode equivalent.

The returned function is K-polymorphic: the derivative order is inferred per call from the number of coefficients in the input jet tuples. The collapsed flag selects between two propagation regimes (the coefficient shape contract for each mode is given below); to freeze the order into an FX GraphModule (e.g. for graph passes like CSE), apply capture_graph to the returned callable yourself.

  • standard mode (collapsed=False): every coefficient c_k has the primal's shape S.
  • collapsed mode (collapsed=True): coefficients c_1..c_{K-1} have shape (R, *S) carrying R directions; c_K has shape S (already summed over the directions). The K-th output coefficient is likewise returned collapsed. This exploits the fact that the highest-order coefficient enters linearly, so it can be summed eagerly and smaller tensors are propagated through the graph. Requires K >= 2; the number of directions R may vary from call to call.
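To make the two shape contracts concrete, here is a minimal pure-Python sketch for a scalar f = sin (no torch, no jet). It assumes the coefficients follow the derivative normalization f_2 = f''(x0)*x1**2 + f'(x0)*x2, which is consistent with laplacian returning the collapsed second coefficient directly as the Hessian trace; the helpers jet2_sin and collapsed_jet2_sin are hypothetical. Summing R standard 2-jets gives the same second coefficient as one collapsed 2-jet whose c_2 input is already summed, because c_2 enters linearly.

```python
import math


def jet2_sin(x0: float, x1: float, x2: float) -> tuple[float, float, float]:
    """Standard-mode 2-jet of sin (assumed derivative normalization)."""
    s, c = math.sin(x0), math.cos(x0)
    return s, c * x1, -s * x1**2 + c * x2


def collapsed_jet2_sin(
    x0: float, X1: list[float], x2_sum: float
) -> tuple[float, list[float], float]:
    """Collapsed-mode 2-jet of sin.

    X1 holds R first-order directions (shape (R,) for a scalar primal);
    x2_sum is the second-order input coefficient already summed over them.
    """
    s, c = math.sin(x0), math.cos(x0)
    F1 = [c * v for v in X1]  # one first-order output per direction
    # Squared directions enter through f''; the summed x2 enters linearly via f'.
    F2 = -s * sum(v**2 for v in X1) + c * x2_sum
    return s, F1, F2


x0 = 0.123
X1 = [-0.456, 0.2, 1.5]  # R = 3 directions
X2 = [0.789, 0.0, -0.1]  # per-direction second-order inputs

# Standard mode: propagate one full 2-jet per direction, then sum c_2.
summed = sum(jet2_sin(x0, v1, v2)[2] for v1, v2 in zip(X1, X2))

# Collapsed mode: sum c_2 up front and propagate a single value.
_, _, F2 = collapsed_jet2_sin(x0, X1, sum(X2))
print(abs(summed - F2) < 1e-12)  # True
```

The saving in collapsed mode is that only one second-order value flows through the computation instead of R of them.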

Parameters:

  • f (Callable[[*(tuple[PyTree[Tensor], ...])], PyTree[Tensor]]) –

    Function to overload. May accept and return pytrees of tensors.

  • mock_args (tuple[PyTree[Tensor], ...]) –

    Mock input tensors (or pytrees of tensors) for tracing f's compute graph, provided as a tuple matching the positional arguments of f. Only shapes and dtypes matter, not the values.

  • collapsed (bool, default: False ) –

    Select between the two propagation regimes above. Default: False (standard mode).

Returns:

  • Callable[[*(tuple[PyTree[Jet], ...])], PyTree[Jet]]

    A callable jet_f(*args) taking one positional argument per argument of f. Each argument is a pytree mirroring the corresponding mock_args entry but with every tensor leaf replaced by a tuple (primal, c_1, ..., c_K) bundling the primal with its K Taylor coefficients. Returns a pytree mirroring f's output structure with each tensor leaf replaced by (f_0, f_1, ..., f_K).

Examples:

Single-input:

>>> from torch import sin, zeros, Tensor
>>> from jet import jet
>>> jet_f = jet(sin, (zeros(1),))
>>> x0, x1, x2 = Tensor([0.123]), Tensor([-0.456]), Tensor([0.789])
>>> f0, f1, f2 = jet_f((x0, x1, x2))

Multi-input:

>>> from torch import cos
>>> f = lambda x, y: sin(x) * cos(y)
>>> jet_f = jet(f, (zeros(3), zeros(3)))
>>> x, y = Tensor([0.1, 0.2, 0.3]), Tensor([0.4, 0.5, 0.6])
>>> vx, vy = Tensor([1.0, 0.0, 0.0]), Tensor([0.0, 1.0, 0.0])
>>> f0, f1 = jet_f((x, vx), (y, vy))
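In the multi-input example, jet_f returns 1-jets, and the first coefficient should follow the product rule, f1 = cos(x)*cos(y)*vx - sin(x)*sin(y)*vy, elementwise. A stdlib-only reference sketch that computes this value by hand for the same inputs (the name jet1_sin_times_cos is hypothetical; it is a plain reimplementation for checking results, not how jet works internally):

```python
import math


def jet1_sin_times_cos(x, vx, y, vy):
    """Reference 1-jet of f(x, y) = sin(x) * cos(y), applied elementwise.

    Returns (f0, f1) where f1 is the directional derivative along (vx, vy).
    """
    f0, f1 = [], []
    for xi, vxi, yi, vyi in zip(x, vx, y, vy):
        f0.append(math.sin(xi) * math.cos(yi))
        # Product rule: d[sin(x) cos(y)] = cos(x) cos(y) dx - sin(x) sin(y) dy
        f1.append(math.cos(xi) * math.cos(yi) * vxi - math.sin(xi) * math.sin(yi) * vyi)
    return f0, f1


# Same values as the doctest above.
x, y = [0.1, 0.2, 0.3], [0.4, 0.5, 0.6]
vx, vy = [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]
f0, f1 = jet1_sin_times_cos(x, vx, y, vy)
# Component 0 moves only x, component 1 only y, component 2 not at all.
print(f1[2])  # 0.0
```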
Source code in jet/__init__.py
def jet(
    f: Callable[[*tuple[PyTree[Tensor], ...]], PyTree[Tensor]],
    mock_args: tuple[PyTree[Tensor], ...],
    collapsed: bool = False,
) -> Callable[[*tuple[PyTree[Jet], ...]], PyTree[Jet]]:
    """Overload a function with its Taylor-mode equivalent.

    The returned function is K-polymorphic: the derivative order is inferred
    per call from the number of coefficients in the input jet tuples. The
    ``collapsed`` flag selects between two propagation regimes (the
    coefficient shape contract for each mode is given below); to freeze the
    order into an FX ``GraphModule`` (e.g. for graph passes like CSE), apply
    :func:`capture_graph` to the returned callable yourself.

    - **standard mode** (``collapsed=False``): every coefficient ``c_k`` has
      the primal's shape ``S``.
    - **collapsed mode** (``collapsed=True``): coefficients ``c_1..c_{K-1}``
      have shape ``(R, *S)`` carrying ``R`` directions; ``c_K`` has shape ``S``
      (already summed over the directions). The K-th output coefficient is
      likewise returned collapsed. This exploits that the highest-order
      coefficient enters linearly, so it can be summed eagerly to propagate
      smaller tensors through the graph. Requires ``K >= 2``; the number of
      directions ``R`` may vary from call to call.

    Args:
        f: Function to overload. May accept and return pytrees of tensors.
        mock_args: Mock input tensors (or pytrees of tensors) for tracing
            ``f``'s compute graph, provided as a tuple matching the positional
            arguments of ``f``. Only shapes and dtypes matter, not the values.
        collapsed: Select between the two propagation regimes above. Default:
            ``False`` (standard mode).

    Returns:
        A callable ``jet_f(*args)`` taking one positional argument per
        argument of ``f``. Each argument is a pytree mirroring the
        corresponding ``mock_args`` entry but with every tensor leaf
        replaced by a tuple ``(primal, c_1, ..., c_K)`` bundling the primal
        with its ``K`` Taylor coefficients. Returns a pytree mirroring ``f``'s
        output structure with each tensor leaf replaced by
        ``(f_0, f_1, ..., f_K)``.

    Examples:
        **Single-input**::

            >>> from torch import sin, zeros, Tensor
            >>> from jet import jet
            >>> jet_f = jet(sin, (zeros(1),))
            >>> x0, x1, x2 = Tensor([0.123]), Tensor([-0.456]), Tensor([0.789])
            >>> f0, f1, f2 = jet_f((x0, x1, x2))

        **Multi-input**::

            >>> from torch import cos
            >>> f = lambda x, y: sin(x) * cos(y)
            >>> jet_f = jet(f, (zeros(3), zeros(3)))
            >>> x, y = Tensor([0.1, 0.2, 0.3]), Tensor([0.4, 0.5, 0.6])
            >>> vx, vy = Tensor([1.0, 0.0, 0.0]), Tensor([0.0, 1.0, 0.0])
            >>> f0, f1 = jet_f((x, vx), (y, vy))
    """
    mod, _ = capture_graph(f, mock_args)
    interp = JetInterpreter(mod, collapsed=collapsed)

    def transformed(*args: PyTree[Jet]) -> PyTree[Jet]:
        leaves, K, R = validate_input_jet(mock_args, args, collapsed=collapsed)
        return interp.run(K, R, *leaves)

    return transformed

jet.laplacian.laplacian

laplacian(
    f: Callable[[Tensor], Tensor],
    mock_args: tuple[PyTree[Tensor], ...],
    randomization: tuple[str, int] | None = None,
    weighting: tuple[Callable[[Tensor, Tensor], Tensor], int] | None = None,
    collapsed: bool = True,
) -> Callable[[*(tuple[PyTree[Tensor], ...])], Tensor]

Transform f into a function that computes lap(f(x)).

The Laplacian of a function \(f(\mathbf{x}) \in \mathbb{R}\) with \(\mathbf{x} \in \mathbb{R}^D\) is defined as the Hessian trace, or

\[ \Delta f(\mathbf{x}) = \sum_{d=1}^D \frac{\partial^2 f(\mathbf{x})}{\partial x_d^2} \in \mathbb{R}\,. \]

For functions that produce vectors or tensors, the Laplacian is defined per output component and has the same shape as \(f(\mathbf{x})\).

Only single-tensor functions (one tensor in, one tensor out) are supported.
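The definition can be checked numerically with a stdlib-only sketch: summing the second coefficient of 2-jets pushed along each basis vector e_d (first-order coefficient e_d, second-order input zero) reproduces the Hessian trace. This mirrors the deterministic, unweighted path of the source code below (identity directions, zero second-order input), but the toy function f(x) = tanh(w . x + b), its weights, and the helper names are illustrative assumptions, and the reference uses finite differences rather than autodiff.

```python
import math

W = (0.5, -1.0, 2.0)  # illustrative weights
B = 0.3               # illustrative bias


def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))


def f(x):
    """Toy scalar function f(x) = tanh(w . x + b)."""
    return math.tanh(dot(W, x) + B)


def second_coefficient(x, v):
    """c_2 of the 2-jet of f at x with c_1 = v and zero second-order input.

    Under the derivative normalization this is d^2/dt^2 f(x + t v) at t = 0.
    """
    t = math.tanh(dot(W, x) + B)
    tanh_dd = -2.0 * t * (1.0 - t * t)  # tanh'' = -2 tanh (1 - tanh^2)
    return tanh_dd * dot(W, v) ** 2


x0 = [0.1, 0.7, -0.2]
D = len(x0)
basis = [[1.0 if i == d else 0.0 for i in range(D)] for d in range(D)]

# Taylor-mode view: one 2-jet per basis direction, second coefficients summed.
lap_taylor = sum(second_coefficient(x0, e) for e in basis)

# Reference: central finite differences for the Hessian diagonal, then summed.
h = 1e-4
lap_fd = 0.0
for e in basis:
    xp = [xi + h * ei for xi, ei in zip(x0, e)]
    xm = [xi - h * ei for xi, ei in zip(x0, e)]
    lap_fd += (f(xp) - 2.0 * f(x0) + f(xm)) / h**2

print(abs(lap_taylor - lap_fd) < 1e-5)  # True
```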

Parameters:

  • f (Callable[[Tensor], Tensor]) –

    The function whose Laplacian is computed. Must consume and return a single tensor.

  • mock_args (tuple[PyTree[Tensor], ...]) –

    Mock positional arguments for tracing f, provided as a tuple matching f's positional arguments. Does not need to be the actual input; only shapes and dtypes matter. Currently must be a one-tuple of a single tensor.

  • randomization (tuple[str, int] | None, default: None ) –

    Optional tuple containing the distribution type and number of samples for randomized Laplacian. If provided, the Laplacian will be computed using Monte-Carlo sampling. The first element is the distribution type (e.g., 'normal', 'rademacher'), and the second is the number of samples to use.

  • weighting (tuple[Callable[[Tensor, Tensor], Tensor], int] | None, default: None ) –

    A tuple specifying how the second-order derivatives should be weighted. The weights are described by a coefficient tensor C(x) of shape [*D, *D]. The first entry is a function (x, V) -> V @ S(x).T that applies, at the input x, the factor S(x) of the symmetric factorization C(x) = S(x) @ S(x).T to the matrix V. S(x) has shape [*D, rank_C], while V has shape [K, rank_C] for an arbitrary number of rows K. The second entry specifies rank_C. If None, the weights correspond to the identity matrix (i.e., the standard Laplacian is computed).

  • collapsed (bool, default: True ) –

    Whether to use collapsed Taylor mode. If True (default), uses the collapsed dispatch path (JetInterpreter(..., collapsed=True)) that directly propagates the summed second-order coefficient. If False, propagates full 2-jets over all directions via vmap and sums afterward.
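A stdlib-only sketch of what the weighting and randomization options compute, assuming a fixed stand-in Hessian H (its values and the helper name quad are illustrative). Pushing a direction v through a 2-jet yields v^T H v. Summing over the columns of S (the rows of V @ S.T for V = identity) gives the weighted trace tr(C H) with C = S S^T, and averaging over Rademacher directions gives an unbiased Monte-Carlo estimate of tr(H) (Hutchinson's estimator), which matches the 1 / num_samples scaling in the source code below.

```python
import random


def quad(H, v):
    """Second coefficient v^T H v that a 2-jet along direction v produces."""
    n = len(v)
    return sum(v[i] * H[i][j] * v[j] for i in range(n) for j in range(n))


# Stand-in (symmetric) Hessian of f at some x; values are illustrative.
H = [[2.0, 0.5, -1.0],
     [0.5, 1.0, 0.3],
     [-1.0, 0.3, 4.0]]
D = len(H)
trace_H = sum(H[d][d] for d in range(D))

# Unweighted: identity directions recover trace(H).
unweighted = sum(quad(H, [float(i == d) for i in range(D)]) for d in range(D))

# Weighted: C = S S^T with S of shape [D, rank_C]. For V = identity, the rows
# of V @ S.T are the columns of S; summing their quadratic forms gives trace(C H).
S = [[1.0, 0.0],
     [0.5, 2.0],
     [0.0, -1.0]]
rank_C = len(S[0])
columns = [[S[i][k] for i in range(D)] for k in range(rank_C)]
weighted = sum(quad(H, s) for s in columns)
C = [[sum(S[i][k] * S[j][k] for k in range(rank_C)) for j in range(D)] for i in range(D)]
trace_CH = sum(C[i][j] * H[j][i] for i in range(D) for j in range(D))

# Randomized: Rademacher directions, averaged (the 1 / num_samples scaling),
# give an unbiased Monte-Carlo estimate of trace(H).
rng = random.Random(0)
num_samples = 20_000
estimate = sum(
    quad(H, [rng.choice((-1.0, 1.0)) for _ in range(D)]) for _ in range(num_samples)
) / num_samples

print(trace_H, abs(weighted - trace_CH) < 1e-12)  # 7.0 True
```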

Returns:

  • Callable[[*(tuple[PyTree[Tensor], ...])], Tensor]

    A plain Python callable lap_f(*args) that maps x → lap(f(x)). To bake the operator into an FX GraphModule (for graph passes, torch.compile, etc.), apply capture_graph yourself.

Examples:

>>> from torch import manual_seed, rand, zeros
>>> from torch.func import hessian
>>> from torch.nn import Linear, Tanh, Sequential
>>> from jet.laplacian import laplacian
>>> _ = manual_seed(0) # make deterministic
>>> f = Sequential(Linear(3, 1), Tanh())
>>> x0 = rand(3)
>>> # Compute the Laplacian via Taylor mode
>>> lap = laplacian(f, (zeros(3),))(x0)
>>> assert lap.shape == f(x0).shape
>>> # Compute the Laplacian with PyTorch's autodiff (Hessian trace)
>>> lap_pt = hessian(f)(x0).squeeze(0).trace().unsqueeze(0)
>>> assert lap.shape == lap_pt.shape
>>> assert lap_pt.allclose(lap)
Source code in jet/laplacian.py
def laplacian(
    f: Callable[[Tensor], Tensor],
    mock_args: tuple[PyTree[Tensor], ...],
    randomization: tuple[str, int] | None = None,
    weighting: tuple[Callable[[Tensor, Tensor], Tensor], int] | None = None,
    collapsed: bool = True,
) -> Callable[[*tuple[PyTree[Tensor], ...]], Tensor]:
    r"""Transform f into a function that computes lap(f(x)).

    The Laplacian of a function $f(\mathbf{x}) \in \mathbb{R}$ with
    $\mathbf{x} \in \mathbb{R}^D$ is defined as the Hessian trace, or

    $$
    \Delta f(\mathbf{x})
    =
    \sum_{d=1}^D
    \frac{\partial^2 f(\mathbf{x})}{\partial x_d^2} \in \mathbb{R}\,.
    $$

    For functions that produce vectors or tensors, the Laplacian
    is defined per output component and has the same shape as $f(\mathbf{x})$.

    Only single-tensor functions (one tensor in, one tensor out) are supported.

    Args:
        f: The function whose Laplacian is computed. Must consume and return a
            single tensor.
        mock_args: Mock positional arguments for tracing ``f``, provided as a
            tuple matching ``f``'s positional arguments. Does not need to be the
            actual input; only shapes and dtypes matter. Currently must be a
            one-tuple of a single tensor.
        randomization: Optional tuple containing the distribution type and number
            of samples for randomized Laplacian. If provided, the Laplacian will
            be computed using Monte-Carlo sampling. The first element is the
            distribution type (e.g., 'normal', 'rademacher'), and the second is the
            number of samples to use.
        weighting: A tuple specifying how the second-order derivatives should be
            weighted. This is described by a coefficient tensor C(x) of shape
            `[*D, *D]`. The first entry is a function (x, V) -> V @ S(x).T that
            applies the symmetric factorization S(x) of the weights
            C(x) = S(x) @ S(x).T at the input x to the matrix V. S(x) has shape
            `[*D, rank_C]` while V is `[K, rank_C]` with arbitrary `K`. The second
            entry specifies `rank_C`. If `None`, then the weightings correspond to
            the identity matrix (i.e. computing the standard Laplacian).
        collapsed: Whether to use collapsed Taylor mode. If ``True``
            (default), uses the collapsed dispatch path
            (``JetInterpreter(..., collapsed=True)``) that directly propagates
            the summed second-order coefficient. If ``False``, propagates full
            2-jets over all directions via ``vmap`` and sums afterward.

    Returns:
        A plain Python callable ``lap_f(*args)`` that maps ``x → lap(f(x))``.
        To bake the operator into an FX ``GraphModule`` (for graph passes,
        ``torch.compile``, etc.), apply :func:`capture_graph` yourself.

    Examples:
        >>> from torch import manual_seed, rand, zeros
        >>> from torch.func import hessian
        >>> from torch.nn import Linear, Tanh, Sequential
        >>> from jet.laplacian import laplacian
        >>> _ = manual_seed(0) # make deterministic
        >>> f = Sequential(Linear(3, 1), Tanh())
        >>> x0 = rand(3)
        >>> # Compute the Laplacian via Taylor mode
        >>> lap = laplacian(f, (zeros(3),))(x0)
        >>> assert lap.shape == f(x0).shape
        >>> # Compute the Laplacian with PyTorch's autodiff (Hessian trace)
        >>> lap_pt = hessian(f)(x0).squeeze(0).trace().unsqueeze(0)
        >>> assert lap.shape == lap_pt.shape
        >>> assert lap_pt.allclose(lap)
    """
    mock_x = require_single_tensor_input(mock_args, "laplacian")
    in_shape = mock_x.shape
    in_dim = mock_x.numel()

    rank_weightings = in_dim if weighting is None else weighting[1]

    validate_randomization(randomization, SUPPORTED_DISTRIBUTIONS)

    num_jets = rank_weightings if randomization is None else randomization[1]
    apply_weightings = (
        (lambda x, V: V.reshape(num_jets, *in_shape))
        if weighting is None
        else weighting[0]
    )

    cjet_f = (
        jet(f, mock_args, collapsed=True)
        if collapsed
        else _uncollapsed_via_vmap(f, mock_args, randomization)
    )

    def lap_f(*args: PyTree[Tensor]) -> Tensor:
        """Compute the (weighted and/or randomized) Laplacian of f at x.

        Args:
            *args: Positional arguments for ``f`` (currently a single tensor
                matching the mock input's shape).

        Returns:
            The (weighted and/or randomized) Laplacian. Has the same shape as
                ``f(x)``.

        Raises:
            ValueError: If the input shape does not match the mock input shape.
        """
        (x,) = args
        if x.shape != in_shape:
            raise ValueError(f"Expected input shape {in_shape}, got {x.shape}.")

        # Set up first Taylor coefficients
        shape = (num_jets, rank_weightings)
        in_meta = {"dtype": x.dtype, "device": x.device}
        V = (
            eye(rank_weightings, **in_meta)
            if randomization is None
            else sample(x, randomization[0], shape)
        )
        X1 = apply_weightings(x, V)
        z = zeros_like(x)

        _, _, F2 = require_single_tensor_output(cjet_f((x, X1, z)), "laplacian")

        if randomization is not None:
            monte_carlo_scaling = 1.0 / randomization[1]
            F2 = F2 * monte_carlo_scaling

        return F2

    return lap_f

jet.bilaplacian.bilaplacian

bilaplacian(
    f: Callable[[Tensor], Tensor],
    mock_args: tuple[PyTree[Tensor], ...],
    randomization: tuple[str, int] | None = None,
    collapsed: bool = True,
) -> Callable[[*(tuple[PyTree[Tensor], ...])], Tensor]

Transform f into a function that computes the Bi-Laplacian.

The Bi-Laplacian of a function \(f(\mathbf{x}) \in \mathbb{R}\) with \(\mathbf{x} \in \mathbb{R}^D\) is defined as the Laplacian of the Laplacian, or

\[ \Delta^2 f(\mathbf{x}) = \sum_{i=1}^D \sum_{j=1}^D \frac{\partial^4 f(\mathbf{x})}{\partial x_i^2 \partial x_j^2} \in \mathbb{R}\,. \]

For functions that produce vectors or tensors, the Bi-Laplacian is defined per output component and has the same shape as \(f(\mathbf{x})\).

Only single-tensor functions (one tensor in, one tensor out) are supported.

Parameters:

  • f (Callable[[Tensor], Tensor]) –

    The function whose Bi-Laplacian is computed. Must consume and return a single tensor.

  • mock_args (tuple[PyTree[Tensor], ...]) –

    Mock positional arguments for tracing f, provided as a tuple matching f's positional arguments. Only shapes and dtypes matter, not the values. Currently must be a one-tuple of a single tensor.

  • randomization (tuple[str, int] | None, default: None ) –

    Optional tuple containing the distribution type and number of samples for randomized Bi-Laplacian. If provided, the Bi-Laplacian will be computed using Monte-Carlo sampling. The first element is the distribution type (must be 'normal'), and the second is the number of samples to use. Default is None.

  • collapsed (bool, default: True ) –

    Whether to use collapsed Taylor mode. If True (default), uses the collapsed dispatch path (JetInterpreter(..., collapsed=True)) that directly propagates the summed fourth-order coefficient. If False, propagates full 4-jets over all directions via vmap and sums afterward.
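The randomized branch in the source code below divides the summed fourth-order coefficient by 3 * num_samples. The factor 3 comes from Gaussian moments: for v ~ N(0, I), E[v_i v_j v_k v_l] = δ_ij δ_kl + δ_ik δ_jl + δ_il δ_jk, so the expected fourth directional derivative equals three times the Bi-Laplacian. A stdlib-only sketch checking this on f(x) = |x|^4, an illustrative choice with a known closed form; the helper name fourth_coefficient is hypothetical.

```python
import random

D = 3
num_samples = 100_000
rng = random.Random(0)


def fourth_coefficient(v):
    """d^4/dt^4 of f(x + t v) for f(x) = |x|^4, independent of x.

    The t^4 coefficient of |x + t v|^4 is |v|^4, so the derivative is 24 |v|^4.
    """
    return 24.0 * sum(vi * vi for vi in v) ** 2


# Closed form: Laplacian of r^4 is 4 (D + 2) r^2 and Laplacian of r^2 is 2 D.
bilap_exact = 8.0 * D * (D + 2)

# Standard normal directions: E[d^4 f[v, v, v, v]] = 3 * bilaplacian,
# so the sample sum is divided by 3 * num_samples (not just num_samples).
total = sum(
    fourth_coefficient([rng.gauss(0.0, 1.0) for _ in range(D)])
    for _ in range(num_samples)
)
bilap_mc = total / (3 * num_samples)

print(bilap_exact)  # 120.0
```

Dropping the factor 3 would make the estimate converge to 360 instead of 120 for this f.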

Returns:

  • Callable[[*(tuple[PyTree[Tensor], ...])], Tensor]

    A plain Python callable bilap_f(*args) that maps x → bilap(f(x)). To bake the operator into an FX GraphModule (for graph passes, torch.compile, etc.), apply capture_graph yourself.

Examples:

>>> from torch import manual_seed, rand, zeros
>>> from torch.func import hessian
>>> from torch.nn import Linear, Tanh, Sequential
>>> from jet.bilaplacian import bilaplacian
>>> _ = manual_seed(0) # make deterministic
>>> f = Sequential(Linear(3, 1), Tanh())
>>> x0 = rand(3)
>>> # Compute the Bilaplacian via Taylor mode
>>> bilap = bilaplacian(f, (zeros(3),))(x0)
>>> assert bilap.shape == f(x0).shape
>>> # Compute the Bilaplacian with PyTorch's autodiff
>>> laplacian_pt = lambda x: hessian(f)(x).squeeze(0).trace().unsqueeze(0)
>>> bilaplacian_pt = hessian(laplacian_pt)(x0).squeeze(0).trace().unsqueeze(0)
>>> assert bilap.shape == bilaplacian_pt.shape
>>> assert bilaplacian_pt.allclose(bilap)
Source code in jet/bilaplacian.py
def bilaplacian(
    f: Callable[[Tensor], Tensor],
    mock_args: tuple[PyTree[Tensor], ...],
    randomization: tuple[str, int] | None = None,
    collapsed: bool = True,
) -> Callable[[*tuple[PyTree[Tensor], ...]], Tensor]:
    r"""Transform f into a function that computes the Bi-Laplacian.

    The Bi-Laplacian of a function $f(\mathbf{x}) \in \mathbb{R}$ with
    $\mathbf{x} \in \mathbb{R}^D$ is defined as the Laplacian of the Laplacian, or

    $$
    \Delta^2 f(\mathbf{x})
    =
    \sum_{i=1}^D \sum_{j=1}^D
    \frac{\partial^4 f(\mathbf{x})}{\partial x_i^2 \partial x_j^2} \in \mathbb{R}\,.
    $$

    For functions that produce vectors or tensors, the Bi-Laplacian
    is defined per output component and has the same shape as $f(\mathbf{x})$.

    Only single-tensor functions (one tensor in, one tensor out) are supported.

    Args:
        f: The function whose Bi-Laplacian is computed. Must consume and return
            a single tensor.
        mock_args: Mock positional arguments for tracing ``f``, provided as a
            tuple matching ``f``'s positional arguments. Only shapes and dtypes
            matter, not the values. Currently must be a one-tuple of a single
            tensor.
        randomization: Optional tuple containing the distribution type and number
            of samples for randomized Bi-Laplacian. If provided, the Bi-Laplacian
            will be computed using Monte-Carlo sampling. The first element is the
            distribution type (must be 'normal'), and the second is the number of
            samples to use. Default is `None`.
        collapsed: Whether to use collapsed Taylor mode. If ``True``
            (default), uses the collapsed dispatch path
            (``JetInterpreter(..., collapsed=True)``) that directly propagates
            the summed fourth-order coefficient. If ``False``, propagates full
            4-jets over all directions via ``vmap`` and sums afterward.

    Returns:
        A plain Python callable ``bilap_f(*args)`` that maps ``x → bilap(f(x))``.
        To bake the operator into an FX ``GraphModule`` (for graph passes,
        ``torch.compile``, etc.), apply :func:`capture_graph` yourself.

    Examples:
        >>> from torch import manual_seed, rand, zeros
        >>> from torch.func import hessian
        >>> from torch.nn import Linear, Tanh, Sequential
        >>> from jet.bilaplacian import bilaplacian
        >>> _ = manual_seed(0) # make deterministic
        >>> f = Sequential(Linear(3, 1), Tanh())
        >>> x0 = rand(3)
        >>> # Compute the Bilaplacian via Taylor mode
        >>> bilap = bilaplacian(f, (zeros(3),))(x0)
        >>> assert bilap.shape == f(x0).shape
        >>> # Compute the Bilaplacian with PyTorch's autodiff
        >>> laplacian_pt = lambda x: hessian(f)(x).squeeze(0).trace().unsqueeze(0)
        >>> bilaplacian_pt = hessian(laplacian_pt)(x0).squeeze(0).trace().unsqueeze(0)
        >>> assert bilap.shape == bilaplacian_pt.shape
        >>> assert bilaplacian_pt.allclose(bilap)
    """
    mock_x = require_single_tensor_input(mock_args, "bilaplacian")
    in_shape = mock_x.shape
    in_dim = mock_x.numel()

    validate_randomization(randomization, SUPPORTED_DISTRIBUTIONS)

    cjet_f = (
        jet(f, mock_args, collapsed=True)
        if collapsed
        else _uncollapsed_via_vmap(f, mock_args, randomization)
    )

    def _eval_4jet(x: Tensor, X1: Tensor) -> Tensor:
        """Evaluate the 4-jet for directions X1 and return the 4th coefficient.

        Args:
            x: Input tensor.
            X1: First-order directions, shape (R, *in_shape).

        Returns:
            The (collapsed or summed) 4th-order coefficient.
        """
        z = zeros_like(x)
        R = X1.shape[0]
        Z = zeros(R, *in_shape, dtype=x.dtype, device=x.device)
        result = require_single_tensor_output(cjet_f((x, X1, Z, Z, z)), "bilaplacian")
        _, _, _, _, F4 = result
        return F4

    def bilap_f(*args: PyTree[Tensor]) -> Tensor:
        """Compute the Bi-Laplacian of the function at the input tensor.

        Args:
            *args: Positional arguments for ``f`` (currently a single tensor
                matching the mock input's shape).

        Returns:
            The Bi-Laplacian. Has the same shape as f(x).

        Raises:
            ValueError: If the input shape does not match the expected shape.
        """
        (x,) = args
        if x.shape != in_shape:
            raise ValueError(f"Expected input shape {in_shape}, got {x.shape}.")

        if randomization is not None:
            distribution, num_samples = randomization
            X1 = sample(x, distribution, (num_samples, *in_shape))
            F4 = _eval_4jet(x, X1)
            return F4 / (3 * num_samples)

        C1, C2, C3 = _set_up_taylor_coefficients(x)
        D = in_dim

        gamma_4_4 = float(compute_all_gammas((4,))[(4,)])
        gammas = compute_all_gammas((2, 2))
        gamma_4_0 = float(gammas[(4, 0)])
        F4_1 = _eval_4jet(x, C1)
        factor1 = (gamma_4_4 + 2 * (D - 1) * gamma_4_0) / 24
        term1 = factor1 * F4_1

        if D == 1:
            return term1

        gamma_3_1 = float(gammas[(3, 1)])
        F4_2 = _eval_4jet(x, C2)
        term2 = 2 * gamma_3_1 / 24 * F4_2

        gamma_2_2 = float(gammas[(2, 2)])
        F4_3 = _eval_4jet(x, C3)
        term3 = 2 * gamma_2_2 / 24 * F4_3

        return term1 + term2 + term3

    return bilap_f

Graph capture

jet.capture_graph

capture_graph(
    f: Callable[[*(tuple[PyTree[Tensor], ...])], PyTree[Tensor]],
    mock_args: tuple[PyTree[Tensor], ...],
) -> tuple[GraphModule, TreeSpec]

Capture the compute graph of f as a GraphModule.

The returned GraphModule's forward takes the flat tensor leaves of mock_args as positional arguments: make_fx cannot trace through pytree containers, so mock_args is flattened and f is traced over those flat leaves. To call the captured graph with f's original pytree structure, use the second return value, in_spec, to flatten new arguments in the order the graph expects:

mod, in_spec = capture_graph(f, mock_args)
out = mod(*in_spec.flatten_up_to(args))

Output structure is passed through unchanged: whatever f returns (single tensor, tuple, dict, arbitrary pytree), make_fx's pytree codegen reconstructs it on each call.
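The flatten-then-trace step can be illustrated without torch. The sketch below uses hypothetical toy flatten/unflatten helpers (a simplified stand-in for torch.utils._pytree's tree_flatten and tree_unflatten, not their actual implementation) and wraps f the same way the flat_f closure in the source code does, so the traced callable only ever sees flat leaves.

```python
from typing import Any


def flatten(tree: Any) -> tuple[list[Any], Any]:
    """Toy flatten: tuples, lists, and dicts are containers; anything else is a leaf."""
    if isinstance(tree, dict):
        keys = list(tree)  # insertion order in this toy
        pairs = [flatten(tree[k]) for k in keys]
        return [leaf for leaves, _ in pairs for leaf in leaves], ("dict", keys, [s for _, s in pairs])
    if isinstance(tree, (tuple, list)):
        pairs = [flatten(child) for child in tree]
        return [leaf for leaves, _ in pairs for leaf in leaves], (type(tree), None, [s for _, s in pairs])
    return [tree], None  # None marks a leaf


def unflatten(leaves: list[Any], spec: Any) -> Any:
    """Toy inverse of flatten: rebuild the container structure from flat leaves."""
    it = iter(leaves)

    def build(s: Any) -> Any:
        if s is None:
            return next(it)
        kind, keys, children = s
        built = [build(c) for c in children]
        return dict(zip(keys, built)) if kind == "dict" else kind(built)

    return build(spec)


def f(params: dict, x: float) -> float:
    """Example f whose first positional argument is a dict pytree."""
    return params["w"] * x + params["b"]


mock_args = ({"w": 2.0, "b": 1.0}, 3.0)
flat_mocks, in_spec = flatten(mock_args)


def flat_f(*flat_leaves: Any) -> Any:
    # Same pattern as capture_graph's wrapper: rebuild the pytree, then call f.
    return f(*unflatten(list(flat_leaves), in_spec))


print(flat_mocks, flat_f(*flat_mocks))  # [2.0, 1.0, 3.0] 7.0
```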

Parameters:

  • f (Callable[[*(tuple[PyTree[Tensor], ...])], PyTree[Tensor]]) –

    Callable to trace (plain function, nn.Module, GraphModule, etc.). May accept pytrees of tensors as positional args.

  • mock_args (tuple[PyTree[Tensor], ...]) –

    Mock inputs (pytrees of tensors) matching f's positional args, provided as a tuple. Only shapes and dtypes matter.

Returns:

  • tuple[GraphModule, TreeSpec]

    A pair (mod, in_spec) where mod is a GraphModule whose forward takes the flat tensor leaves of mock_args, and in_spec is the TreeSpec of mock_args. Use in_spec.flatten_up_to(args) to obtain the flat leaves in the order mod expects.

Raises:

  • TypeError

    If mock_args is not a tuple. Wrap a single positional argument as (x,).

Source code in jet/tracing.py
def capture_graph(
    f: Callable[[*tuple[PyTree[Tensor], ...]], PyTree[Tensor]],
    mock_args: tuple[PyTree[Tensor], ...],
) -> tuple[GraphModule, TreeSpec]:
    """Capture the compute graph of ``f`` as a ``GraphModule``.

    The returned ``GraphModule``'s ``forward`` takes the flat tensor leaves
    of ``mock_args`` as positional arguments -- ``make_fx`` cannot trace
    through pytree containers, so ``mock_args`` is flattened and ``f`` is
    traced over those flat leaves. To call the captured graph with ``f``'s
    original pytree shape, use the second return value, ``in_spec``, to
    flatten new arguments in the order the graph expects::

        mod, in_spec = capture_graph(f, mock_args)
        out = mod(*in_spec.flatten_up_to(args))

    Output structure is passed through unchanged: whatever ``f`` returns
    (single tensor, tuple, dict, arbitrary pytree), ``make_fx``'s pytree
    codegen reconstructs it on each call.

    Args:
        f: Callable to trace (plain function, ``nn.Module``, ``GraphModule``,
            etc.). May accept pytrees of tensors as positional args.
        mock_args: Mock inputs (pytrees of tensors) matching ``f``'s positional
            args, provided as a tuple. Only shapes and dtypes matter.

    Returns:
        A pair ``(mod, in_spec)`` where ``mod`` is a ``GraphModule`` whose
        forward takes the flat tensor leaves of ``mock_args``, and
        ``in_spec`` is the ``TreeSpec`` of ``mock_args``. Use
        ``in_spec.flatten_up_to(args)`` to obtain the flat leaves in the
        order ``mod`` expects.

    Raises:
        TypeError: If ``mock_args`` is not a ``tuple``. Wrap a single
            positional argument as ``(x,)``.
    """
    if not isinstance(mock_args, tuple):
        raise TypeError(
            f"mock_args must be a tuple of f's positional arguments, got "
            f"{type(mock_args).__name__}; wrap a single argument as ``(x,)``."
        )
    _assert_traceable_signature(mock_args)
    flat_mocks, in_spec = tree_flatten(mock_args)

    def flat_f(*flat_tensors: Tensor) -> PyTree[Tensor]:
        return f(*tree_unflatten(list(flat_tensors), in_spec))

    mod = _make_fx(functionalize(flat_f))(*flat_mocks)
    _replace_inplace_ops(mod)
    mod.graph.eliminate_dead_code()
    mod.recompile()
    return mod, in_spec

jet.simplify.common_subexpression_elimination

common_subexpression_elimination(graph: Graph, verbose: bool = False) -> bool

Replace duplicate subexpressions with a single node.

Parameters:

  • graph (Graph) –

    The graph to be optimized.

  • verbose (bool, default: False ) –

    Whether to print debug information. Default: False.

Returns:

  • bool

    Whether a subexpression was replaced.
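A toy, stdlib-only sketch of the same keying idea on a hypothetical list-of-tuples IR (not torch.fx; kwargs are omitted). It rewrites each node's arguments as it goes, which is what replace_all_uses_with achieves in the library version, so later nodes can become duplicates once an earlier duplicate is removed. Unlike the toy, the library version leaves replaced nodes in the graph for eliminate_dead_code to drop. When the real pass is run on a GraphModule's graph (e.g. mod.graph), torch.fx requires a mod.recompile() afterwards so the generated forward reflects the change.

```python
def cse(nodes):
    """Toy CSE over (name, op, target, args) tuples in topological order.

    Two nodes are duplicates when op, target, and (already rewritten) args match.
    Returns the surviving nodes and a map from eliminated to surviving names.
    """
    seen = {}    # (op, target, args) -> surviving node name
    rename = {}  # eliminated name -> surviving name
    kept = []
    for name, op, target, args in nodes:
        # Redirect uses of eliminated nodes before computing the key.
        args = tuple(rename.get(a, a) for a in args)
        key = (op, target, args)
        if key in seen:
            rename[name] = seen[key]
        else:
            seen[key] = name
            kept.append((name, op, target, args))
    return kept, rename


graph = [
    ("x", "placeholder", "x", ()),
    ("a", "call_function", "sin", ("x",)),
    ("b", "call_function", "sin", ("x",)),      # duplicate of a
    ("c", "call_function", "mul", ("a", "b")),
    ("d", "call_function", "mul", ("a", "a")),  # duplicate of c once b -> a
    ("out", "output", "output", ("c", "d")),
]
kept, rename = cse(graph)
print([n[0] for n in kept], rename)  # ['x', 'a', 'c', 'out'] {'b': 'a', 'd': 'c'}
```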

Source code in jet/simplify.py
def common_subexpression_elimination(graph: Graph, verbose: bool = False) -> bool:
    """Replace duplicate subexpressions with a single node.

    Args:
        graph: The graph to be optimized.
        verbose: Whether to print debug information. Default: `False`.

    Returns:
        Whether a subexpression was replaced.
    """
    nodes = {}

    replaced = False
    num_replacements = 0

    for node in list(graph.nodes):
        node_hash = (node.op, node.target, node.args, node.kwargs)
        if node_hash in nodes:
            # replace the node
            replacement = nodes[node_hash]
            if verbose:
                print(
                    f"Replacing {node}"
                    + f" ({node.op}, {node.target}, {node.args}, {node.kwargs})\nwith"
                    + f" {replacement} ({replacement.op}, {replacement.target},"
                    + f" {replacement.args}, {replacement.kwargs})"
                )
            node.replace_all_uses_with(replacement)

            replaced = True
            num_replacements += 1
        else:
            nodes[node_hash] = node

    if replaced:
        graph.eliminate_dead_code()

    if verbose:
        print(f"Replacements: {num_replacements}")

    return replaced

jet.utils.visualize_graph

visualize_graph(
    mod: GraphModule, savefile: str, name: str = "", use_custom: bool = False
)

Visualize the compute graph of a module.

Supported formats: .png, .pdf, .svg (inferred from savefile).
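Because the format is taken from Path(savefile).suffix and lower-cased, only the last extension counts and matching is case-insensitive; a path without an extension raises ValueError. A stdlib-only sketch of that check (the name resolve_format and the SUPPORTED tuple are hypothetical; they mirror the lookup in the source code below without requiring graphviz or torch):

```python
from pathlib import Path

SUPPORTED = (".pdf", ".png", ".svg")


def resolve_format(savefile: str) -> str:
    """Return the lower-cased suffix that would select the output format, or raise."""
    suffix = Path(savefile).suffix.lower()
    if suffix not in SUPPORTED:
        raise ValueError(
            f"Unsupported file format {suffix!r}. Use one of: {', '.join(SUPPORTED)}"
        )
    return suffix


print(resolve_format("graph.PDF"))           # .pdf
print(resolve_format("out/run.v2/cse.svg"))  # .svg
```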

Parameters:

  • mod (GraphModule) –

    The module whose compute graph to visualize.

  • savefile (str) –

    The path to the file where the graph should be saved.

  • name (str, default: '' ) –

    A name for the graph, used in the visualization.

  • use_custom (bool, default: False ) –

    If True, highlight sum nodes in orange-red and use white for other operations. Defaults to False.

Raises:

  • ValueError

    If savefile has an unsupported extension.

Source code in jet/utils.py
def visualize_graph(
    mod: GraphModule, savefile: str, name: str = "", use_custom: bool = False
):
    """Visualize the compute graph of a module.

    Supported formats: ``.png``, ``.pdf``, ``.svg`` (inferred from *savefile*).

    Args:
        mod: The module whose compute graph to visualize.
        savefile: The path to the file where the graph should be saved.
        name: A name for the graph, used in the visualization.
        use_custom: If ``True``, highlight sum nodes in orange-red and use white
            for other operations. Defaults to ``False``.

    Raises:
        ValueError: If *savefile* has an unsupported extension.
    """
    cls = _CustomDrawer if use_custom else FxGraphDrawer
    drawer = cls(mod, name)
    dot_graph = drawer.get_dot_graph()

    creators = {
        ".png": dot_graph.create_png,
        ".pdf": dot_graph.create_pdf,
        ".svg": dot_graph.create_svg,
    }

    suffix = Path(savefile).suffix.lower()
    creator = creators.get(suffix)
    if creator is None:
        supported = ", ".join(sorted(creators))
        raise ValueError(f"Unsupported file format {suffix!r}. Use one of: {supported}")

    with open(savefile, "wb") as f:
        f.write(creator())