Shortcuts

torch.cond

cond_branch_class_method

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class MySubModule(torch.nn.Module):
    """Small module whose ``forward`` returns the element-wise cosine of its input."""

    def foo(self, x):
        """Return the element-wise cosine of ``x``."""
        cosine = torch.cos(x)
        return cosine

    def forward(self, x):
        """Delegate to :meth:`foo`."""
        return self.foo(x)


class CondBranchClassMethod(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - returned tensor must have the same tensor metadata, e.g. shape and dtype
      - branch function can be free function, nested function, lambda, class methods
      - branch function can not have closure variables
      - no inplace mutations on inputs or global variables


    This example demonstrates using class method in cond().

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self):
        super().__init__()
        # The submodule's bound `forward` method is used as the true branch.
        self.subm = MySubModule()

    def bar(self, x):
        """False branch: return the element-wise sine of `x`."""
        return x.sin()

    def forward(self, x):
        """Run `self.subm.forward` when `x.shape[0] <= 2`, otherwise `self.bar`."""
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3]):
            #
            sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
            eq = sym_size_int == 3;  sym_size_int = None
            scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
            _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default = None
            submodule_0 = self.submodule_0
            submodule_1 = self.submodule_1
            cond: f32[3] = torch.ops.cond(False, submodule_0, submodule_1, [arg0_1]);  submodule_0 = submodule_1 = arg0_1 = None
            return (cond,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3]):
                        cos_default: f32[3] = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return cos_default

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3]):
                        sin_default: f32[3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin_default

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

cond_branch_nested_function

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


def cond_branch_nested_function(x):
    """
    Demonstrate nested functions used as the branches of cond().

    Branch functions (`true_fn` and `false_fn`) handed to cond() are subject to these rules:
      - both branches take the same args, and those args match the operands passed to cond.
      - both branches return a single tensor
      - the returned tensors share the same tensor metadata, e.g. shape and dtype
      - a branch may be a free function, nested function, lambda, or class method
      - a branch can not have closure variables
      - no inplace mutations on inputs or global variables

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def false_fn(x):
        # Inner helper nested one level deeper inside the branch.
        def subtract_from_x(y):
            return x - y

        return subtract_from_x(x)

    def true_fn(x):
        # Inner helper nested one level deeper inside the branch.
        def add_to_x(y):
            return x + y

        return add_to_x(x)

    leading_dim_below_ten = x.shape[0] < 10
    return cond(leading_dim_below_ten, true_fn, false_fn, [x])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3]):
            #
            sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
            eq = sym_size_int == 3;  sym_size_int = None
            scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
            _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default = None
            submodule_0 = self.submodule_0
            submodule_1 = self.submodule_1
            cond: f32[3] = torch.ops.cond(True, submodule_0, submodule_1, [arg0_1]);  submodule_0 = submodule_1 = arg0_1 = None
            return (cond,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3]):
                        add_tensor: f32[3] = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
                return add_tensor

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3]):
                        sub_tensor: f32[3] = torch.ops.aten.sub.Tensor(arg0_1, arg0_1);  arg0_1 = None
                return sub_tensor

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

cond_branch_nonlocal_variables

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


def cond_branch_nonlocal_variables(x):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - returned tensor must have the same tensor metadata, e.g. shape and dtype
      - branch function can be free function, nested function, lambda, class methods
      - branch function can not have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions.

    The code below will not work because capturing closure variables is not supported.
    ```
    my_tensor_var = x + 100
    my_primitive_var = 3.14

    def true_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y + my_tensor_var + my_primitive_var

    def false_fn(y):
        nonlocal my_tensor_var, my_primitive_var
        return y - my_tensor_var - my_primitive_var

    return cond(x.shape[0] > 5, true_fn, false_fn, [x])
    ```

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    my_tensor_var = x + 100
    my_primitive_var = 3.14

    # The values the branches need are passed in as explicit operands
    # instead of being captured from the enclosing scope.
    def true_fn(x, y, z):
        return x + y + z

    def false_fn(x, y, z):
        return x - y - z

    return cond(
        x.shape[0] > 5,
        true_fn,
        false_fn,
        # The Python float is wrapped in a tensor before being passed as an operand.
        [x, my_tensor_var, torch.tensor(my_primitive_var)],
    )

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[6]):
            #
            sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
            eq = sym_size_int == 6;  sym_size_int = None
            scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
            _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[0] is specialized at 6');  scalar_tensor_default = None
            add_tensor: f32[6] = torch.ops.aten.add.Tensor(arg0_1, 100)
            _tensor_constant0: f32[] = self._tensor_constant0
            lift_fresh_copy_default: f32[] = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
            submodule_0 = self.submodule_0
            submodule_1 = self.submodule_1
            cond: f32[6] = torch.ops.cond(True, submodule_0, submodule_1, [arg0_1, add_tensor, lift_fresh_copy_default]);  submodule_0 = submodule_1 = arg0_1 = add_tensor = lift_fresh_copy_default = None
            return (cond,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[6], arg1_1: f32[6], arg2_1: f32[]):
                        add_tensor: f32[6] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                add_tensor_1: f32[6] = torch.ops.aten.add.Tensor(add_tensor, arg2_1);  add_tensor = arg2_1 = None
                return add_tensor_1

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[6], arg1_1: f32[6], arg2_1: f32[]):
                        sub_tensor: f32[6] = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                sub_tensor_1: f32[6] = torch.ops.aten.sub.Tensor(sub_tensor, arg2_1);  sub_tensor = arg2_1 = None
                return sub_tensor_1

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

cond_closed_over_variable

Note

Tags: python.closure, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


class CondClosedOverVariable(torch.nn.Module):
    """
    torch.cond() supports branches closed over arbitrary variables.

    Both branches below ignore their `val` operand and read `x` from the
    enclosing `forward` scope instead.
    """

    def forward(self, pred, x):
        def false_fn(val):
            # `x` is the closed-over argument of `forward`, not the operand.
            return x - 2

        def true_fn(val):
            # `x` is the closed-over argument of `forward`, not the operand.
            return x * 2

        operands = [x + 1]
        return cond(pred, true_fn, false_fn, operands)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: b8[], arg1_1: f32[3, 2]):
            #
            sym_size_int = torch.ops.aten.sym_size.int(arg1_1, 0)
            sym_size_int_1 = torch.ops.aten.sym_size.int(arg1_1, 1)
            eq = sym_size_int_1 == 2;  sym_size_int_1 = None
            scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
            _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg1_1.shape[1] is specialized at 2');  scalar_tensor_default = None
            eq_1 = sym_size_int == 3;  sym_size_int = None
            scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
            _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg1_1.shape[0] is specialized at 3');  scalar_tensor_default_1 = None
            add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg1_1, 1)
            submodule_0 = self.submodule_0
            submodule_1 = self.submodule_1
            cond: f32[3, 2] = torch.ops.cond(arg0_1, submodule_0, submodule_1, [add_tensor, arg1_1, arg1_1]);  arg0_1 = submodule_0 = submodule_1 = add_tensor = arg1_1 = None
            return (cond,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2], arg1_1: f32[3, 2], arg2_1: f32[3, 2]):
                        mul_tensor: f32[3, 2] = torch.ops.aten.mul.Tensor(arg2_1, 2);  arg2_1 = None
                return mul_tensor

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2], arg1_1: f32[3, 2], arg2_1: f32[3, 2]):
                        sub_tensor: f32[3, 2] = torch.ops.aten.sub.Tensor(arg2_1, 2);  arg2_1 = None
                return sub_tensor

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

cond_operands

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from torch._export import dynamic_dim
from functorch.experimental.control_flow import cond

# Example tensors; their names match the parameters of `cond_operands` below.
x = torch.randn(3, 2)
y = torch.ones(2)
# Constraint built from dimension 0 of `x`.
# NOTE(review): presumably passed to export so that dim 0 stays dynamic - confirm against the caller.
dynamic_constraint = dynamic_dim(x, 0)

def cond_operands(x, y):
    """
    Show the operand list that cond() forwards to its branch functions.

    The operands passed to cond() must be:
      - a list of tensors
      - match arguments of `true_fn` and `false_fn`

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def false_fn(x, y):
        difference = x - y
        return difference

    def true_fn(x, y):
        total = x + y
        return total

    leading_dim_above_two = x.shape[0] > 2
    operands = [x, y]
    return cond(leading_dim_above_two, true_fn, false_fn, operands)

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[s0, 2], arg1_1: f32[2]):
            #
            sym_size_int: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
            sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
            sym_size_int_2 = torch.ops.aten.sym_size.int(arg1_1, 0)
            eq = sym_size_int_2 == 2;  sym_size_int_2 = None
            scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
            _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg1_1.shape[0] is specialized at 2');  scalar_tensor_default = None
            eq_1 = sym_size_int_1 == 2;  sym_size_int_1 = None
            scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
            _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 2');  scalar_tensor_default_1 = None
            sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
            gt: Sym(s0 > 2) = sym_size > 2;  sym_size = None
            submodule_0 = self.submodule_0
            submodule_1 = self.submodule_1
            cond: f32[s0, 2] = torch.ops.cond(gt, submodule_0, submodule_1, [arg0_1, arg1_1]);  gt = submodule_0 = submodule_1 = arg0_1 = arg1_1 = None
            return (cond,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 2], arg1_1: f32[2]):
                        add_tensor: f32[s0, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return add_tensor

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 2], arg1_1: f32[2]):
                        sub_tensor: f32[s0, 2] = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return sub_tensor

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {s0: RangeConstraint(min_val=2, max_val=oo)}

cond_predicate

Note

Tags: torch.dynamic-shape, torch.cond

Support Level: SUPPORTED

Original source code:

import torch

from functorch.experimental.control_flow import cond


def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to cond() must be one of the following:
      - torch.Tensor with a single element
      - boolean expression

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    # `and` short-circuits, so `x.shape[2]` is only read when `x` has more than two dims.
    pred = x.dim() > 2 and x.shape[2] > 10

    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

Result:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[6, 4, 3]):
            #
            sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
            sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
            sym_size_int_2 = torch.ops.aten.sym_size.int(arg0_1, 2)
            eq = sym_size_int_2 == 3;  sym_size_int_2 = None
            scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
            _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[2] is specialized at 3');  scalar_tensor_default = None
            eq_1 = sym_size_int_1 == 4;  sym_size_int_1 = None
            scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
            _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 4');  scalar_tensor_default_1 = None
            eq_2 = sym_size_int == 6;  sym_size_int = None
            scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2);  eq_2 = None
            _assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[0] is specialized at 6');  scalar_tensor_default_2 = None
            submodule_0 = self.submodule_0
            submodule_1 = self.submodule_1
            cond: f32[6, 4, 3] = torch.ops.cond(False, submodule_0, submodule_1, [arg0_1]);  submodule_0 = submodule_1 = arg0_1 = None
            return (cond,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[6, 4, 3]):
                        cos_default: f32[6, 4, 3] = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return cos_default

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[6, 4, 3]):
                        sin_default: f32[6, 4, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                return sin_default

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources