torch.map
=============
dynamic_shape_map
^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.map <torch.map>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch

    from functorch.experimental.control_flow import map


    def dynamic_shape_map(xs, y):
        """
        functorch map() maps a function over the first tensor dimension.
        """

        def body(x, y):
            return x + y

        return map(body, xs, y)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2], arg1_1: f32[2]):
                #
                sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
                sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
                sym_size_int_2 = torch.ops.aten.sym_size.int(arg1_1, 0)
                eq = sym_size_int_2 == 2;  sym_size_int_2 = None
                scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
                _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg1_1.shape[0] is specialized at 2');  scalar_tensor_default = None
                eq_1 = sym_size_int_1 == 2;  sym_size_int_1 = None
                scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
                _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 2');  scalar_tensor_default_1 = None
                eq_2 = sym_size_int == 3;  sym_size_int = None
                scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2);  eq_2 = None
                _assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default_2 = None
                submodule_0 = self.submodule_0
                map_impl = torch.ops.map_impl(submodule_0, 1, arg0_1, arg1_1);  submodule_0 = arg0_1 = arg1_1 = None
                getitem: f32[3, 2] = map_impl[0];  map_impl = None
                return (getitem,)

            class GraphModule(torch.nn.Module):
                def forward(self, arg0_1: f32[3, 2], arg1_1: f32[2]):
                    add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                    return [add_tensor]

    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['getitem'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {}