• Docs >
  • torch.func interaction with torch.compile
Shortcuts

torch.func interaction with torch.compile

So you want to use a torch.func (“functorch”) transform (like vmap, grad, jacrev, etc) with torch.compile. Here’s a guide to what works today, what doesn’t, and how to work around it.

Applying a torch.func transform to a torch.compile’d function

This doesn’t work and is being tracked by https://github.com/pytorch/pytorch/issues/100320.

import torch

@torch.compile
def f(x):
    return torch.sin(x)

def g(x):
    return torch.grad(f)(x)

x = torch.randn(2, 3)
g(x)

As a workaround, please put the torch.compile outside of the torch.func transform:

import torch

def f(x):
    return torch.sin(x)

@torch.compile
def g(x):
    return torch.vmap(f)(x)

x = torch.randn(2, 3)
g(x)

Doesn’t work (PT 2.0): calling a torch.func transform inside of a torch.compile’ed function

import torch

@torch.compile
def f(x):
    return torch.vmap(torch.sum)(x)

x = torch.randn(2, 3)
f(x)

This doesn’t work yet. Please see the workaround (the next section).

Workaround: use torch._dynamo.allow_in_graph

allow_in_graph is an escape hatch. If your code does not work with torch.compile, which introspects Python bytecode, but you believe it will work via a symbolic tracing approach (like jax.jit), then use allow_in_graph.

By using allow_in_graph to annotate a function, you promise PyTorch a couple of things that we are unable to completely verify: - Your function is pure. That is, all outputs only depend on the inputs and do not depend on any captured Tensors. - Your function is functional. That is, it does not mutate any state. This may be relaxed; we actually support functions that appear to be functional from the outside: they may have in-place PyTorch operations, but may not mutate global state or inputs to the function. - Your function does not raise data-dependent errors.

import torch

@torch.compile
def f(x):
    return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)

x = torch.randn(2, 3)
f(x)

A common pitfall is using allow_in_graph to annotate a function that invokes an nn.Module. This is because the outputs now depend on the parameters of the nn.Module. To actually get this to work, use torch.func.functional_call to extract the module state.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources