diff --git a/onnxscript/_internal/evaluator.py b/onnxscript/_internal/evaluator.py index e290eaf7ac..e7c022fd59 100644 --- a/onnxscript/_internal/evaluator.py +++ b/onnxscript/_internal/evaluator.py @@ -251,13 +251,25 @@ def eval_op( args: Sequence[ExtendedModeValue], kwargs: Mapping[str, ExtendedModeValue], ): + from onnxscript._internal.typed_sequence import ( + _TypedSequence, # pylint: disable=import-outside-toplevel + ) + op_signature = op.op_signature assert op_signature is not None, f"Op {op.name} has no signature." attributes = _unwrap_tensors_in_kwargs(kwargs) attributes, closure = self._adapt_attributes(op_signature, attributes) inputs = self._adapt_inputs(op_signature, args) outputs = self._eval(op.op_schema, inputs, attributes, closure) - return self._adapt_outputs(outputs) + result = self._adapt_outputs(outputs) + + # Handle SequenceEmpty: wrap result with _TypedSequence to preserve dtype + # for empty sequences where type cannot be inferred from elements. + if op.name == "SequenceEmpty": + dtype = attributes.get("dtype") or onnx.TensorProto.FLOAT + return _TypedSequence(dtype, result) + + return result def eval_function( self, @@ -364,8 +376,15 @@ def compute_num_outputs( def _onnxscript_to_numpy_value(v): """Converts an onnxscript encoding of an ONNX value into the numpy encoding used by runtimes.""" + from onnxscript._internal.typed_sequence import ( + _TypedSequence, # pylint: disable=import-outside-toplevel + ) + if isinstance(v, tensor.Tensor): return v.value + # Handle _TypedSequence BEFORE generic list check to preserve dtype info + if isinstance(v, _TypedSequence): + return _TypedSequence(v.onnx_dtype, [_onnxscript_to_numpy_value(x) for x in v]) if isinstance(v, list): return [_onnxscript_to_numpy_value(x) for x in v] if isinstance(v, tuple): diff --git a/onnxscript/_internal/typed_sequence.py b/onnxscript/_internal/typed_sequence.py new file mode 100644 index 0000000000..90824dbf88 --- /dev/null +++ b/onnxscript/_internal/typed_sequence.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Any, Iterable + + +class _TypedSequence(list): + """A list subclass that preserves ONNX dtype information. + + ONNX sequences need to carry element type information, but Python lists + don't have this natively. This class wraps a list with dtype info, which is + especially important for empty sequences where type cannot be inferred. + + Example:: + seq = _TypedSequence(onnx.TensorProto.FLOAT16) + # seq is now an empty list that knows it should contain FLOAT16 tensors + """ + + __slots__ = ("_onnx_dtype",) + + def __init__(self, dtype: int, iterable: Iterable[Any] | None = None) -> None: + super().__init__(iterable if iterable is not None else ()) + self._onnx_dtype = dtype + + @property + def onnx_dtype(self) -> int: + """Return the ONNX TensorProto data type for sequence elements.""" + return self._onnx_dtype + + def __repr__(self) -> str: + return f"_TypedSequence(dtype={self._onnx_dtype}, {list(self)})" diff --git a/onnxscript/_internal/utils.py b/onnxscript/_internal/utils.py index c5e05b3d92..0d7bbf63d3 100644 --- a/onnxscript/_internal/utils.py +++ b/onnxscript/_internal/utils.py @@ -10,6 +10,7 @@ import onnx_ir as ir from onnxscript import tensor +from onnxscript._internal.typed_sequence import _TypedSequence def external_tensor( @@ -72,6 +73,13 @@ def value_to_type_proto(val): return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, []) # noqa: TID251 if isinstance(val, (float, np.float32)): return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, []) # noqa: TID251 + + # Handle _TypedSequence BEFORE generic list check + if isinstance(val, _TypedSequence): + return onnx.helper.make_sequence_type_proto( # noqa: TID251 + onnx.helper.make_tensor_type_proto(val.onnx_dtype, None) # noqa: TID251 + ) + if isinstance(val, list): if len(val) > 0: return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0])) # noqa: TID251 diff --git a/tests/eager_mode_test.py b/tests/eager_mode_test.py index e4cb0ab313..a1031d6731 100644 --- a/tests/eager_mode_test.py +++ b/tests/eager_mode_test.py @@ -4,6 +4,7 @@ import unittest import numpy as np +import onnx import parameterized import onnxscript.evaluator @@ -51,6 +52,49 @@ def Concat(seq): output2 = Concat([os_tensor, os_tensor]) self.assertIsInstance(output2, onnxscript.tensor.Tensor) + def test_sequence_empty_preserves_dtype(self): + """Regression test for SequenceEmpty dtype parameter. + + Verify that SequenceEmpty correctly preserves the dtype attribute + so that SequenceInsert doesn't fail with type mismatch errors. + """ + + @script() + def test_float16(img_in): + seq = op.SequenceEmpty(dtype=onnx.TensorProto.FLOAT16) + seq = op.SequenceInsert(seq, img_in) + return seq + + @script() + def test_double(img_in): + seq = op.SequenceEmpty(dtype=onnx.TensorProto.DOUBLE) + seq = op.SequenceInsert(seq, img_in) + return seq + + @script() + def test_int64(img_in): + seq = op.SequenceEmpty(dtype=onnx.TensorProto.INT64) + seq = op.SequenceInsert(seq, img_in) + return seq + + # Test FLOAT16 + img_float16 = np.random.randn(2, 3).astype(np.float16) + res = test_float16(img_float16) + self.assertEqual(len(res), 1) + np.testing.assert_array_equal(res[0], img_float16) + + # Test DOUBLE + img_double = np.random.randn(2, 3).astype(np.float64) + res = test_double(img_double) + self.assertEqual(len(res), 1) + np.testing.assert_array_equal(res[0], img_double) + + # Test INT64 + img_int64 = np.array([[1, 2], [3, 4]], dtype=np.int64) + res = test_int64(img_int64) + self.assertEqual(len(res), 1) + np.testing.assert_array_equal(res[0], img_int64) + @script() def add_with_alpha(this, other, alpha: float = 1.0):