python.assert
dynamic_shape_assert
Original source code:
import torch

def dynamic_shape_assert(x):
    """
    Basic usage of Python assertions.
    """
    # assertion with error message
    assert x.shape[0] > 2, f"{x.shape[0]} is not greater than 2"
    # assertion without error message
    assert x.shape[0] > 1
    return x
Result:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2]):
            #
            sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
            sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
            eq = sym_size_int_1 == 2;  sym_size_int_1 = None
            scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
            _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2');  scalar_tensor_default = None
            eq_1 = sym_size_int == 3;  sym_size_int = None
            scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
            _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default_1 = None
            return (arg0_1,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['arg0_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
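The exported graph checks that each input dimension equals the traced size (3 and 2), which is stricter than the original `> 2` and `> 1` assertions. The following is a minimal sketch of that difference in plain Python, with shape tuples standing in for tensors. The helpers `original_asserts`, `specialized_checks`, and `passes` are hypothetical names introduced for illustration and are not part of torch.

```python
def original_asserts(shape):
    """Mirror the two assertions in dynamic_shape_assert on a plain shape tuple."""
    assert shape[0] > 2, f"{shape[0]} is not greater than 2"
    assert shape[0] > 1
    return True


def specialized_checks(shape, traced_shape=(3, 2)):
    """Mirror the equality checks seen in the exported graph (hypothetical helper)."""
    for dim, (actual, traced) in enumerate(zip(shape, traced_shape)):
        assert actual == traced, f"Input arg0_1.shape[{dim}] is specialized at {traced}"
    return True


def passes(check, shape):
    """Return True if the check accepts the shape, False if it raises AssertionError."""
    try:
        return check(shape)
    except AssertionError:
        return False


# The traced shape satisfies both sets of checks.
print(passes(original_asserts, (3, 2)), passes(specialized_checks, (3, 2)))  # True True
# A larger first dimension satisfies the original assertions only.
print(passes(original_asserts, (4, 2)), passes(specialized_checks, (4, 2)))  # True False
```

Under these assumptions, a shape of (4, 2) is accepted by the source function but rejected by the specialized checks, which matches the 'is specialized at 3' message in the graph.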
list_contains
Original source code:
import torch

def list_contains(x):
    """
    List containment relation can be checked on a dynamic shape or constants.
    """
    assert x.size(-1) in [6, 2]
    assert x.size(0) not in [4, 5, 6]
    assert "monkey" not in ["cow", "pig"]
    return x + x
Result:
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[3, 2]):
            #
            sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
            sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
            eq = sym_size_int_1 == 2;  sym_size_int_1 = None
            scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
            _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2');  scalar_tensor_default = None
            eq_1 = sym_size_int == 3;  sym_size_int = None
            scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
            _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default_1 = None
            add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
            return (add_tensor,)

Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
Symbol to range: {}
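The graph above contains no containment check, only the shape specialization checks and the addition. One plausible reading is that, for the traced input of shape (3, 2), each `in` / `not in` test reduces to an ordinary Python boolean. The sketch below evaluates the three tests on plain shape tuples. `containment_results` is a hypothetical helper for illustration, not a torch API.

```python
def containment_results(shape):
    """Evaluate the three containment tests from list_contains on a shape tuple."""
    return [
        shape[-1] in [6, 2],             # x.size(-1) in [6, 2]
        shape[0] not in [4, 5, 6],       # x.size(0) not in [4, 5, 6]
        "monkey" not in ["cow", "pig"],  # constant, independent of the input
    ]


# The traced shape (3, 2) satisfies all three assertions.
print(containment_results((3, 2)))  # [True, True, True]
# A first dimension of 5 would violate the second assertion.
print(containment_results((5, 2)))  # [True, False, True]
```

The third test never depends on the input, so it is the same for every shape.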