ExportDB¶
ExportDB is a centralized dataset of supported and unsupported export cases. It is targeted towards users who want to understand specifically what types of code are supported, the subtleties of export, and how to modify their existing code to be compatible with export. Note that this is not an exhaustive set of everything that is supported by export, but it covers the most common and confusing use cases that users will run into.
If you have a feature that you think needs a stronger guarantee of support in export, please create an issue in the pytorch/pytorch repo with a module:export tag.
Supported¶
assume_constant_result¶
Original source code:
import torch
import torch._dynamo as torchdynamo
class AssumeConstantResult(torch.nn.Module):
    """
    Apply the `assume_constant_result` decorator to burn non-traceable code
    into the exported graph as a constant.
    """

    def __init__(self):
        super().__init__()

    # `get_item` calls `.item()`, the non-traceable code in this example.
    # The decorator makes export use the value returned at trace time as a
    # constant: the exported graph slices `x` with a literal end index and
    # never reads `y`.
    @torchdynamo.assume_constant_result
    def get_item(self, y):
        return y.int().item()

    def forward(self, x, y):
        return x[: self.get_item(y)]
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2], arg1_1: i64[]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
eq = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
slice_tensor: f32[3, 2] = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 4); arg0_1 = None
return (slice_tensor,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['slice_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
autograd_function¶
Note
Tags:
Support Level: SUPPORTED
Original source code:
import torch
class MyAutogradFunction(torch.autograd.Function):
    """
    Custom autograd function used by the `AutogradFunction` example.

    `forward` returns a clone of its input; `backward` returns
    `grad_output + 1`, which is not the true gradient of a clone. Only the
    forward computation appears in the exported graph (as `aten.clone`).
    """

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output + 1
class AutogradFunction(torch.nn.Module):
    """
    TorchDynamo does not keep track of backward() on autograd functions. We recommend
    using `allow_in_graph` to mitigate this problem.
    """

    def forward(self, x):
        # Export traces through MyAutogradFunction.forward; the exported graph
        # contains only its forward op (aten.clone) and no backward signature.
        return MyAutogradFunction.apply(x)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
eq = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
clone_default: f32[3, 2] = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
return (clone_default,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['clone_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
class_method¶
Note
Tags:
Support Level: SUPPORTED
Original source code:
import torch
class ClassMethod(torch.nn.Module):
    """
    Class methods are inlined during tracing.
    """

    @classmethod
    def method(cls, x):
        return x + 1

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        x = self.linear(x)
        # Three equivalent ways to reach the same classmethod; each call is
        # inlined, so the exported graph contains three separate `add` ops.
        return self.method(x) * self.__class__.method(x) * type(self).method(x)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[2, 4], arg1_1: f32[2], arg2_1: f32[3, 4]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg1_1, 0)
sym_size_int_3 = torch.ops.aten.sym_size.int(arg2_1, 0)
sym_size_int_4 = torch.ops.aten.sym_size.int(arg2_1, 1)
eq = sym_size_int_4 == 4; sym_size_int_4 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg2_1.shape[1] is specialized at 4'); scalar_tensor_default = None
eq_1 = sym_size_int_3 == 3; sym_size_int_3 = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg2_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
eq_2 = sym_size_int_2 == 2; sym_size_int_2 = None
scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2); eq_2 = None
_assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg1_1.shape[0] is specialized at 2'); scalar_tensor_default_2 = None
eq_3 = sym_size_int_1 == 4; sym_size_int_1 = None
scalar_tensor_default_3: f32[] = torch.ops.aten.scalar_tensor.default(eq_3); eq_3 = None
_assert_async_msg_3 = torch.ops.aten._assert_async.msg(scalar_tensor_default_3, 'Input arg0_1.shape[1] is specialized at 4'); scalar_tensor_default_3 = None
eq_4 = sym_size_int == 2; sym_size_int = None
scalar_tensor_default_4: f32[] = torch.ops.aten.scalar_tensor.default(eq_4); eq_4 = None
_assert_async_msg_4 = torch.ops.aten._assert_async.msg(scalar_tensor_default_4, 'Input arg0_1.shape[0] is specialized at 2'); scalar_tensor_default_4 = None
permute_default: f32[4, 2] = torch.ops.aten.permute.default(arg0_1, [1, 0]); arg0_1 = None
addmm_default: f32[3, 2] = torch.ops.aten.addmm.default(arg1_1, arg2_1, permute_default); arg1_1 = arg2_1 = permute_default = None
add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(addmm_default, 1)
add_tensor_1: f32[3, 2] = torch.ops.aten.add.Tensor(addmm_default, 1)
mul_tensor: f32[3, 2] = torch.ops.aten.mul.Tensor(add_tensor, add_tensor_1); add_tensor = add_tensor_1 = None
add_tensor_2: f32[3, 2] = torch.ops.aten.add.Tensor(addmm_default, 1); addmm_default = None
mul_tensor_1: f32[3, 2] = torch.ops.aten.mul.Tensor(mul_tensor, add_tensor_2); mul_tensor = add_tensor_2 = None
return (mul_tensor_1,)
Graph Signature: ExportGraphSignature(parameters=['L__self___linear.weight', 'L__self___linear.bias'], buffers=[], user_inputs=['arg2_1'], user_outputs=['mul_tensor_1'], inputs_to_parameters={'arg0_1': 'L__self___linear.weight', 'arg1_1': 'L__self___linear.bias'}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
cond_branch_class_method¶
Original source code:
import torch
from functorch.experimental.control_flow import cond
class MySubModule(torch.nn.Module):
    """
    Helper submodule for the `CondBranchClassMethod` example. Its bound
    `forward` (which calls `foo`, i.e. `x.cos()`) is used as a cond() branch.
    """

    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)
class CondBranchClassMethod(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - returned tensor must have the same tensor metadata, e.g. shape and dtype
      - branch function can be free function, nested function, lambda, class methods
      - branch function can not have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates using class method in cond().

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()
        self.subm = MySubModule()

    def bar(self, x):
        return x.sin()

    def forward(self, x):
        # True branch: bound `forward` of a submodule; false branch: a method of
        # this module. With the 3-element example input the predicate was
        # evaluated at export time and recorded as the constant False.
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
eq = sym_size_int == 3; sym_size_int = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default = None
submodule_0 = self.submodule_0
submodule_1 = self.submodule_1
cond: f32[3] = torch.ops.cond(False, submodule_0, submodule_1, [arg0_1]); submodule_0 = submodule_1 = arg0_1 = None
return (cond,)
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3]):
cos_default: f32[3] = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
return cos_default
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3]):
sin_default: f32[3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
return sin_default
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
cond_branch_nested_function¶
Original source code:
import torch
from functorch.experimental.control_flow import cond
def cond_branch_nested_function(x):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - returned tensor must have the same tensor metadata, e.g. shape and dtype
      - branch function can be free function, nested function, lambda, class methods
      - branch function can not have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates using nested function in cond().

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def true_fn(x):
        # The inner helper reads `x`, which is true_fn's own argument rather
        # than a variable captured from outside the branch.
        def inner_true_fn(y):
            return x + y

        return inner_true_fn(x)

    def false_fn(x):
        def inner_false_fn(y):
            return x - y

        return inner_false_fn(x)

    return cond(x.shape[0] < 10, true_fn, false_fn, [x])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
eq = sym_size_int == 3; sym_size_int = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default = None
submodule_0 = self.submodule_0
submodule_1 = self.submodule_1
cond: f32[3] = torch.ops.cond(True, submodule_0, submodule_1, [arg0_1]); submodule_0 = submodule_1 = arg0_1 = None
return (cond,)
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3]):
add_tensor: f32[3] = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
return add_tensor
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3]):
sub_tensor: f32[3] = torch.ops.aten.sub.Tensor(arg0_1, arg0_1); arg0_1 = None
return sub_tensor
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
cond_branch_nonlocal_variables¶
Original source code:
import torch
from functorch.experimental.control_flow import cond
def cond_branch_nonlocal_variables(x):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - returned tensor must have the same tensor metadata, e.g. shape and dtype
      - branch function can be free function, nested function, lambda, class methods
      - branch function can not have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.

    The code below will not work because capturing closure variables is not supported.
    ```
    my_tensor_var = x + 100
    my_primitive_var = 3.14

    def true_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y + my_tensor_var + my_primitive_var

    def false_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y - my_tensor_var - my_primitive_var

    return cond(x.shape[0] > 5, true_fn, false_fn, [x])
    ```

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """
    my_tensor_var = x + 100
    my_primitive_var = 3.14

    # Instead of capturing them, pass the former closure values explicitly as
    # operands. The Python float is wrapped in a tensor because cond()
    # operands must be tensors.
    def true_fn(x, y, z):
        return x + y + z

    def false_fn(x, y, z):
        return x - y - z

    return cond(
        x.shape[0] > 5,
        true_fn,
        false_fn,
        [x, my_tensor_var, torch.tensor(my_primitive_var)],
    )
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[6]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
eq = sym_size_int == 6; sym_size_int = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[0] is specialized at 6'); scalar_tensor_default = None
add_tensor: f32[6] = torch.ops.aten.add.Tensor(arg0_1, 100)
_tensor_constant0: f32[] = self._tensor_constant0
lift_fresh_copy_default: f32[] = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
submodule_0 = self.submodule_0
submodule_1 = self.submodule_1
cond: f32[6] = torch.ops.cond(True, submodule_0, submodule_1, [arg0_1, add_tensor, lift_fresh_copy_default]); submodule_0 = submodule_1 = arg0_1 = add_tensor = lift_fresh_copy_default = None
return (cond,)
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[6], arg1_1: f32[6], arg2_1: f32[]):
add_tensor: f32[6] = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
add_tensor_1: f32[6] = torch.ops.aten.add.Tensor(add_tensor, arg2_1); add_tensor = arg2_1 = None
return add_tensor_1
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[6], arg1_1: f32[6], arg2_1: f32[]):
sub_tensor: f32[6] = torch.ops.aten.sub.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
sub_tensor_1: f32[6] = torch.ops.aten.sub.Tensor(sub_tensor, arg2_1); sub_tensor = arg2_1 = None
return sub_tensor_1
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
cond_closed_over_variable¶
Original source code:
import torch
from functorch.experimental.control_flow import cond
class CondClosedOverVariable(torch.nn.Module):
    """
    torch.cond() supports branches closed over arbitrary variables.
    """

    def forward(self, pred, x):
        # Both branches ignore their `val` operand and read the captured `x`
        # instead. In the exported graph the captured tensor is lifted into
        # cond's operand list, so each branch subgraph takes extra inputs.
        def true_fn(val):
            return x * 2

        def false_fn(val):
            return x - 2

        return cond(pred, true_fn, false_fn, [x + 1])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: b8[], arg1_1: f32[3, 2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg1_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg1_1, 1)
eq = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg1_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg1_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg1_1, 1)
submodule_0 = self.submodule_0
submodule_1 = self.submodule_1
cond: f32[3, 2] = torch.ops.cond(arg0_1, submodule_0, submodule_1, [add_tensor, arg1_1, arg1_1]); arg0_1 = submodule_0 = submodule_1 = add_tensor = arg1_1 = None
return (cond,)
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2], arg1_1: f32[3, 2], arg2_1: f32[3, 2]):
mul_tensor: f32[3, 2] = torch.ops.aten.mul.Tensor(arg2_1, 2); arg2_1 = None
return mul_tensor
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2], arg1_1: f32[3, 2], arg2_1: f32[3, 2]):
sub_tensor: f32[3, 2] = torch.ops.aten.sub.Tensor(arg2_1, 2); arg2_1 = None
return sub_tensor
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
cond_operands¶
Original source code:
import torch
from torch._export import dynamic_dim
from functorch.experimental.control_flow import cond
# Example inputs for `cond_operands`. Dim 0 of `x` is marked dynamic, which is
# why the exported graph for this example has a symbolic first dimension (s0).
x = torch.randn(3, 2)
y = torch.ones(2)
dynamic_constraint = dynamic_dim(x, 0)
def cond_operands(x, y):
    """
    The operands passed to cond() must be:
      - a list of tensors
      - match arguments of `true_fn` and `false_fn`

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def true_fn(x, y):
        return x + y

    def false_fn(x, y):
        return x - y

    # When dim 0 of `x` is exported as dynamic (as in the result shown for this
    # example), the predicate stays in the graph as a symbolic comparison
    # instead of being folded to a constant.
    return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 2], arg1_1: f32[2]):
#
sym_size_int: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg1_1, 0)
eq = sym_size_int_2 == 2; sym_size_int_2 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg1_1.shape[0] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default_1 = None
sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
gt: Sym(s0 > 2) = sym_size > 2; sym_size = None
submodule_0 = self.submodule_0
submodule_1 = self.submodule_1
cond: f32[s0, 2] = torch.ops.cond(gt, submodule_0, submodule_1, [arg0_1, arg1_1]); gt = submodule_0 = submodule_1 = arg0_1 = arg1_1 = None
return (cond,)
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 2], arg1_1: f32[2]):
add_tensor: f32[s0, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return add_tensor
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[s0, 2], arg1_1: f32[2]):
sub_tensor: f32[s0, 2] = torch.ops.aten.sub.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return sub_tensor
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {s0: RangeConstraint(min_val=2, max_val=oo)}
cond_predicate¶
Original source code:
import torch
from functorch.experimental.control_flow import cond
def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to cond() must be one of the following:
      - torch.Tensor with a single element
      - boolean expression

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """
    # `and` short-circuits in Python, so `x.shape[2]` is only read when `x` has
    # more than two dims. With the static (6, 4, 3) example input the whole
    # expression evaluates to a plain False, which is what the graph records.
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[6, 4, 3]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg0_1, 2)
eq = sym_size_int_2 == 3; sym_size_int_2 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[2] is specialized at 3'); scalar_tensor_default = None
eq_1 = sym_size_int_1 == 4; sym_size_int_1 = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 4'); scalar_tensor_default_1 = None
eq_2 = sym_size_int == 6; sym_size_int = None
scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2); eq_2 = None
_assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[0] is specialized at 6'); scalar_tensor_default_2 = None
submodule_0 = self.submodule_0
submodule_1 = self.submodule_1
cond: f32[6, 4, 3] = torch.ops.cond(False, submodule_0, submodule_1, [arg0_1]); submodule_0 = submodule_1 = arg0_1 = None
return (cond,)
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[6, 4, 3]):
cos_default: f32[6, 4, 3] = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
return cos_default
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[6, 4, 3]):
sin_default: f32[6, 4, 3] = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
return sin_default
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
decorator¶
Note
Tags:
Support Level: SUPPORTED
Original source code:
import functools
import torch
def test_decorator(func):
    """Wrap `func` so that the wrapper returns `func(*args, **kwargs) + 1`."""

    # functools.wraps copies func's metadata (e.g. name, docstring) onto wrapper.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs) + 1

    return wrapper
class Decorator(torch.nn.Module):
    """
    Decorator calls are inlined into the exported function during tracing.
    """

    # The `+ 1` added by `test_decorator` shows up in the exported graph as a
    # second `add` op after `x + y`.
    @test_decorator
    def forward(self, x, y):
        return x + y
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2], arg1_1: f32[3, 2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg1_1, 0)
sym_size_int_3 = torch.ops.aten.sym_size.int(arg1_1, 1)
eq = sym_size_int_3 == 2; sym_size_int_3 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg1_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int_2 == 3; sym_size_int_2 = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg1_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
eq_2 = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2); eq_2 = None
_assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default_2 = None
eq_3 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_3: f32[] = torch.ops.aten.scalar_tensor.default(eq_3); eq_3 = None
_assert_async_msg_3 = torch.ops.aten._assert_async.msg(scalar_tensor_default_3, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_3 = None
add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
add_tensor_1: f32[3, 2] = torch.ops.aten.add.Tensor(add_tensor, 1); add_tensor = None
return (add_tensor_1,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['add_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
dictionary¶
Original source code:
import torch
def dictionary(x, y):
    """
    Dictionary structures are inlined and flattened during tracing.
    """
    elements = {}
    elements["x2"] = x * x
    y = y * elements["x2"]
    # The returned dict is flattened: the exported graph returns its single
    # value as a plain tuple output.
    return {"y": y}
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2], arg1_1: i64[]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
eq = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
mul_tensor: f32[3, 2] = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
mul_tensor_1: f32[3, 2] = torch.ops.aten.mul.Tensor(arg1_1, mul_tensor); arg1_1 = mul_tensor = None
return (mul_tensor_1,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['mul_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
dynamic_shape_assert¶
Original source code:
import torch
def dynamic_shape_assert(x):
    """
    A basic usage of python assertion.

    Args:
        x: input tensor; its first dimension must be greater than 2.

    Returns:
        ``x`` unchanged.

    Raises:
        AssertionError: if ``x.shape[0] <= 2``.
    """
    # assertion with error message; the message is shown only when the check
    # fails, so it must state that the size is *not* greater than 2
    assert x.shape[0] > 2, f"{x.shape[0]} is not greater than 2"
    # assertion without error message
    assert x.shape[0] > 1
    return x
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
eq = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
return (arg0_1,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['arg0_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
dynamic_shape_constructor¶
Original source code:
import torch
def dynamic_shape_constructor(x):
    """
    Tensor constructors should be captured with dynamic shape inputs rather
    than being baked in with static shape.
    """
    # NOTE: in the result shown for this example the input was exported with
    # static shapes, so the size `x.shape[0] * 2` appears as the literal 6.
    return torch.ones(x.shape[0] * 2)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1); arg0_1 = None
eq = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
full_default: f32[6] = torch.ops.aten.full.default([6], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
return (full_default,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['full_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
dynamic_shape_if_guard¶
Original source code:
import torch
class DynamicShapeIfGuard(torch.nn.Module):
    """
    `if` statement with backed dynamic shape predicate will be specialized into
    one particular branch and generate a guard. However, export will fail if the
    dimension is marked as dynamic shape from higher level API.
    """

    def forward(self, x):
        # For the example input (x.shape[0] == 3) only the `cos` branch is
        # traced; the `sin` branch does not appear in the exported graph.
        if x.shape[0] == 3:
            return x.cos()
        return x.sin()
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2, 2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg0_1, 2)
eq = sym_size_int_2 == 2; sym_size_int_2 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[2] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default_1 = None
eq_2 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2); eq_2 = None
_assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_2 = None
cos_default: f32[3, 2, 2] = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
return (cos_default,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cos_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
dynamic_shape_map¶
Original source code:
import torch
from functorch.experimental.control_flow import map
def dynamic_shape_map(xs, y):
    """
    functorch map() maps a function over the first tensor dimension.
    """

    def body(x, y):
        return x + y

    # `map` here is functorch's control-flow map imported above, not the
    # builtin. `body` is applied over the first dimension of `xs`, and `y` is
    # passed through as an extra argument.
    return map(body, xs, y)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2], arg1_1: f32[2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg1_1, 0)
eq = sym_size_int_2 == 2; sym_size_int_2 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg1_1.shape[0] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default_1 = None
eq_2 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2); eq_2 = None
_assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_2 = None
submodule_0 = self.submodule_0
map_impl = torch.ops.map_impl(submodule_0, 1, arg0_1, arg1_1); submodule_0 = arg0_1 = arg1_1 = None
getitem: f32[3, 2] = map_impl[0]; map_impl = None
return (getitem,)
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2], arg1_1: f32[2]):
add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return [add_tensor]
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['getitem'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
dynamic_shape_slicing¶
Original source code:
import torch
def dynamic_shape_slicing(x):
"""
Slices with dynamic shape arguments should be captured into the graph
rather than being baked in.
"""
return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
eq = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
slice_tensor: f32[1, 2] = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 1); arg0_1 = None
slice_tensor_1: f32[1, 1] = torch.ops.aten.slice.Tensor(slice_tensor, 1, 1, 9223372036854775807, 2); slice_tensor = None
return (slice_tensor_1,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['slice_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
dynamic_shape_view¶
Original source code:
import torch
def dynamic_shape_view(x):
"""
Dynamic shapes should be propagated to view arguments instead of being
baked into the exported graph.
"""
new_x_shape = x.size()[:-1] + (2, 5)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[10, 10]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
eq = sym_size_int_1 == 10; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 10'); scalar_tensor_default = None
eq_1 = sym_size_int == 10; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 10'); scalar_tensor_default_1 = None
view_default: f32[10, 2, 5] = torch.ops.aten.view.default(arg0_1, [10, 2, 5]); arg0_1 = None
permute_default: f32[10, 5, 2] = torch.ops.aten.permute.default(view_default, [0, 2, 1]); view_default = None
return (permute_default,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['permute_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
list_contains¶
Original source code:
import torch
def list_contains(x):
"""
List containment relation can be checked on a dynamic shape or constants.
"""
assert x.size(-1) in [6, 2]
assert x.size(0) not in [4, 5, 6]
assert "monkey" not in ["cow", "pig"]
return x + x
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
eq = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 = None
return (add_tensor,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
list_unpack¶
Original source code:
from typing import List
import torch
def list_unpack(args: List[torch.Tensor]):
"""
Lists are treated as a static construct; therefore, unpacking should be
erased after tracing.
"""
x, *y = args
return x + y[0]
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2], arg1_1: i64[], arg2_1: i64[]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
eq = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return (add_tensor,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
nested_function¶
Original source code:
import torch
def nested_function(a, b):
"""
Nested functions are traced through. Side effects on global captures
are not supported though.
"""
x = a + b
z = a - b
def closure(y):
nonlocal x
x += 1
return x * y + z
return closure(x)
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2], arg1_1: f32[2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg1_1, 0)
eq = sym_size_int_2 == 2; sym_size_int_2 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg1_1.shape[0] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default_1 = None
eq_2 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2); eq_2 = None
_assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_2 = None
add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
sub_tensor: f32[3, 2] = torch.ops.aten.sub.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
add_tensor_1: f32[3, 2] = torch.ops.aten.add.Tensor(add_tensor, 1); add_tensor = None
mul_tensor: f32[3, 2] = torch.ops.aten.mul.Tensor(add_tensor_1, add_tensor_1); add_tensor_1 = None
add_tensor_2: f32[3, 2] = torch.ops.aten.add.Tensor(mul_tensor, sub_tensor); mul_tensor = sub_tensor = None
return (add_tensor_2,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['add_tensor_2'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
null_context_manager¶
Original source code:
import contextlib
import torch
def null_context_manager(x):
"""
Null context manager in Python will be traced out.
"""
ctx = contextlib.nullcontext()
with ctx:
return x.sin() + x.cos()
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
eq = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
sin_default: f32[3, 2] = torch.ops.aten.sin.default(arg0_1)
cos_default: f32[3, 2] = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(sin_default, cos_default); sin_default = cos_default = None
return (add_tensor,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
pytree_flatten¶
Note
Tags:
Support Level: SUPPORTED
Original source code:
import torch
from torch.utils import _pytree as pytree
def pytree_flatten(x):
"""
Pytree from PyTorch can be captured by TorchDynamo.
"""
y, spec = pytree.tree_flatten(x)
return y[0] + 1
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2], arg1_1: f32[3, 2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg1_1, 0)
sym_size_int_3 = torch.ops.aten.sym_size.int(arg1_1, 1); arg1_1 = None
eq = sym_size_int_3 == 2; sym_size_int_3 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg1_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int_2 == 3; sym_size_int_2 = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg1_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
eq_2 = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2); eq_2 = None
_assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default_2 = None
eq_3 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_3: f32[] = torch.ops.aten.scalar_tensor.default(eq_3); eq_3 = None
_assert_async_msg_3 = torch.ops.aten._assert_async.msg(scalar_tensor_default_3, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_3 = None
add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
return (add_tensor,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
scalar_output¶
Original source code:
import torch
from torch._export import dynamic_dim
x = torch.ones(3, 2)
dynamic_constraint = dynamic_dim(x, 1)
def scalar_output(x):
"""
Returning scalar values from the graph is supported, in addition to Tensor
outputs. Symbolic shapes are captured and rank is specialized.
"""
return x.shape[1] + 1
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, s0]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 1)
eq = sym_size_int == 3; sym_size_int = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default = None
sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 1); arg0_1 = None
add: Sym(s0 + 1) = sym_size + 1; sym_size = None
return (add,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {s0: RangeConstraint(min_val=2, max_val=oo)}
specialized_attribute¶
Note
Tags:
Support Level: SUPPORTED
Original source code:
from enum import Enum
import torch
class Animal(Enum):
COW = "moo"
class SpecializedAttribute(torch.nn.Module):
"""
Model attributes are specialized.
"""
def __init__(self):
super().__init__()
self.a = "moo"
self.b = 4
def forward(self, x):
if self.a == Animal.COW.value:
return x * x + self.b
else:
raise ValueError("bad")
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
eq = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
mul_tensor: f32[3, 2] = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(mul_tensor, 4); mul_tensor = None
return (add_tensor,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
static_for_loop¶
Original source code:
import torch
class StaticForLoop(torch.nn.Module):
"""
A for loop with constant number of iterations should be unrolled in the exported graph.
"""
def __init__(self):
super().__init__()
def forward(self, x):
ret = []
for i in range(10): # constant
ret.append(i + x)
return ret
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
eq = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 0)
add_tensor_1: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 1)
add_tensor_2: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 2)
add_tensor_3: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 3)
add_tensor_4: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 4)
add_tensor_5: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 5)
add_tensor_6: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 6)
add_tensor_7: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 7)
add_tensor_8: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 8)
add_tensor_9: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 9); arg0_1 = None
return (add_tensor, add_tensor_1, add_tensor_2, add_tensor_3, add_tensor_4, add_tensor_5, add_tensor_6, add_tensor_7, add_tensor_8, add_tensor_9)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add_tensor', 'add_tensor_1', 'add_tensor_2', 'add_tensor_3', 'add_tensor_4', 'add_tensor_5', 'add_tensor_6', 'add_tensor_7', 'add_tensor_8', 'add_tensor_9'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
static_if¶
Original source code:
import torch
class StaticIf(torch.nn.Module):
"""
`if` statement with static predicate value should be traced through with the
taken branch.
"""
def __init__(self):
super().__init__()
def forward(self, x):
if len(x.shape) == 3:
return x + torch.ones(1, 1, 1)
return x
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2, 2]):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
sym_size_int_2 = torch.ops.aten.sym_size.int(arg0_1, 2)
eq = sym_size_int_2 == 2; sym_size_int_2 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[2] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default_1 = None
eq_2 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2); eq_2 = None
_assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_2 = None
full_default: f32[1, 1, 1] = torch.ops.aten.full.default([1, 1, 1], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
add_tensor: f32[3, 2, 2] = torch.ops.aten.add.Tensor(arg0_1, full_default); arg0_1 = full_default = None
return (add_tensor,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
tensor_setattr¶
Original source code:
import torch
def tensor_setattr(x, attr):
"""
setattr() call onto tensors is not supported.
"""
setattr(x, attr, torch.randn(3, 2))
return x + 4
Result:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, arg0_1: f32[3, 2], arg1_1):
#
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
eq = sym_size_int_1 == 2; sym_size_int_1 = None
scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None
_assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None
eq_1 = sym_size_int == 3; sym_size_int = None
scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None
_assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None
add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 4); arg0_1 = None
return (add_tensor,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
Not Supported Yet¶
dynamic_shape_round¶
Original source code:
import torch
from torch._export import dynamic_dim
x = torch.ones(3, 2)
dynamic_constraint = dynamic_dim(x, 0)
def dynamic_shape_round(x):
"""
Calling round on dynamic shapes is not supported.
"""
return x[: round(x.shape[0] / 2)]
Result:
Unsupported: Calling round() on symbolic value is not supported. You can use floor() to implement this functionality
fn_with_kwargs¶
Original source code:
import torch
def fn_with_kwargs(pos0, tuple0, *myargs, mykw0=None, **mykwargs):
"""
Keyword arguments are not supported at the moment.
"""
out = pos0
for arg in tuple0:
out *= arg
for arg in myargs:
out *= arg
out *= mykw0
out *= mykwargs["input0"] * mykwargs["input1"]
return out
Result:
Unsupported: Kwargs to torch.export is not supported
type_reflection_method¶
Original source code:
import torch
class A:
@classmethod
def func(cls, x):
return 1 + x
def type_reflection_method(x):
"""
type() calls on custom objects followed by method calls are not allowed
due to their overly dynamic nature.
"""
a = A()
return type(a).func(x)
Result:
Unsupported: Can't call type() on generated custom object. Please use __class__ instead
You can rewrite the example above to something like the following:
def type_reflection_method_rewrite(x):
"""
Custom object class methods will be inlined.
"""
return A.func(x)