torch.map
dynamic_shape_map
Original source code:
import torch

from functorch.experimental.control_flow import map

def dynamic_shape_map(xs, y):
    """
    functorch map() maps a function over the first tensor dimension.
    """

    def body(x, y):
        return x + y

    return map(body, xs, y)
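As a rough illustration of what "maps a function over the first tensor dimension" means, the sketch below emulates the behavior with plain Python lists instead of tensors. The helper name `reference_map` and the list-based `body` are hypothetical stand-ins, not part of functorch; the real `map` operates on tensors and is traced into the exported graph.

```python
# Pure-Python sketch of the semantics only; `reference_map` is a
# hypothetical helper, not a functorch or torch API.
def reference_map(f, xs, *args):
    """Apply f to each slice along the first dimension and collect the results."""
    return [f(x, *args) for x in xs]


def body(x, y):
    # Elementwise add of two equal-length rows, standing in for `x + y`.
    return [a + b for a, b in zip(x, y)]


xs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # plays the role of an f32[3, 2] input
y = [10.0, 20.0]                            # plays the role of an f32[2] input
out = reference_map(body, xs, y)            # one result row per row of xs
```

Each row of `xs` is passed to `body` together with the unchanged `y`, so the result has the same leading dimension as `xs`.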
Result:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2], arg1_1: f32[2]):
            #
            sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
            sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
            sym_size_int_2 = torch.ops.aten.sym_size.int(arg1_1, 0)
            eq = sym_size_int_2 == 2;  sym_size_int_2 = None
            scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
            _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg1_1.shape[0] is specialized at 2');  scalar_tensor_default = None
            eq_1 = sym_size_int_1 == 2;  sym_size_int_1 = None
            scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
            _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 2');  scalar_tensor_default_1 = None
            eq_2 = sym_size_int == 3;  sym_size_int = None
            scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2);  eq_2 = None
            _assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default_2 = None
            submodule_0 = self.submodule_0
            map_impl = torch.ops.map_impl(submodule_0, 1, arg0_1, arg1_1);  submodule_0 = arg0_1 = arg1_1 = None
            getitem: f32[3, 2] = map_impl[0];  map_impl = None
            return (getitem,)

        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2], arg1_1: f32[2]):
                add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                return [add_tensor]
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['getitem'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
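The three assertion messages in the exported graph show that every input dimension was specialized to the size seen during export (3 and 2 for arg0_1, 2 for arg1_1). The sketch below mirrors those checks on plain shape tuples to show which inputs the exported program would accept. The helper `check_specialized_shapes` is hypothetical, and it checks synchronously with Python `assert`, whereas the graph uses `_assert_async`.

```python
# Hypothetical helper mirroring the three shape assertions in the graph above;
# the messages are copied from the exported program.
def check_specialized_shapes(xs_shape, y_shape):
    assert y_shape[0] == 2, "Input arg1_1.shape[0] is specialized at 2"
    assert xs_shape[1] == 2, "Input arg0_1.shape[1] is specialized at 2"
    assert xs_shape[0] == 3, "Input arg0_1.shape[0] is specialized at 3"


# The shapes used at export time pass all three checks.
check_specialized_shapes((3, 2), (2,))

# A different leading dimension is rejected by the third check.
try:
    check_specialized_shapes((4, 2), (2,))
    rejected = False
    message = ""
except AssertionError as err:
    rejected = True
    message = str(err)
```

Under this reading, an input with four rows fails even though `map` itself could iterate over any leading dimension, because the export recorded a fixed size of 3.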