torch.escape-hatch
assume_constant_result
Original source code:
import torch
import torch._dynamo as torchdynamo


class AssumeConstantResult(torch.nn.Module):
    """
    Applying the `assume_constant_result` decorator burns the result of
    non-traceable code into the graph as a constant.
    """

    def __init__(self):
        super().__init__()

    @torchdynamo.assume_constant_result
    def get_item(self, y):
        # Data-dependent value: treated as a constant during tracing.
        return y.int().item()

    def forward(self, x, y):
        return x[: self.get_item(y)]
Result:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2], arg1_1: i64[]):
            #
            sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
            sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
            eq = sym_size_int_1 == 2;  sym_size_int_1 = None
            scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
            _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2');  scalar_tensor_default = None
            eq_1 = sym_size_int == 3;  sym_size_int = None
            scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
            _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default_1 = None
            slice_tensor: f32[3, 2] = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 4);  arg0_1 = None
            return (slice_tensor,)
Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['slice_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
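Note that in the graph above `arg1_1` is still listed as a user input but is never read, and the slice end is the literal `4`, presumably the value `y` held when the module was exported. The consequence can be illustrated with a toy model in plain Python. This is only a sketch and does not use torch: `toy_trace`, `eager_forward`, and the list `rows` are hypothetical names invented for illustration, and `toy_trace` merely imitates the effect of capturing a constant at trace time.

```python
# Toy model (plain Python, no torch): imitates how the result of a function
# marked as constant is captured once at trace time and reused afterwards.

def get_item(y):
    # Stands in for `y.int().item()` in the module above.
    return int(y)


def eager_forward(x, y):
    # Eager behaviour: the slice end follows the current value of y.
    return x[: get_item(y)]


def toy_trace(fn, example_y):
    """Hypothetical stand-in for tracing: call `fn` once and freeze the result."""
    frozen_end = fn(example_y)  # evaluated only while "tracing"

    def traced_forward(x, y):
        # y is still accepted as an input, but it no longer affects the output.
        return x[:frozen_end]

    return traced_forward


rows = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]  # stands in for the f32[3, 2] input
traced = toy_trace(get_item, example_y=4)     # slice end frozen at 4

eager_out = eager_forward(rows, 1)   # honours the new y
traced_out = traced(rows, 1)         # still slices with end = 4
print(len(eager_out), len(traced_out))
```

With the same `y` that was used for tracing, both versions agree; with a different `y`, only the eager version changes. This is why the decorator is best reserved for values that really are constant across calls.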