Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion onnxscript/_internal/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,25 @@ def eval_op(
args: Sequence[ExtendedModeValue],
kwargs: Mapping[str, ExtendedModeValue],
):
from onnxscript._internal.typed_sequence import (
_TypedSequence, # pylint: disable=import-outside-toplevel
)

op_signature = op.op_signature
assert op_signature is not None, f"Op {op.name} has no signature."
attributes = _unwrap_tensors_in_kwargs(kwargs)
attributes, closure = self._adapt_attributes(op_signature, attributes)
inputs = self._adapt_inputs(op_signature, args)
outputs = self._eval(op.op_schema, inputs, attributes, closure)
return self._adapt_outputs(outputs)
result = self._adapt_outputs(outputs)

# Handle SequenceEmpty: wrap result with _TypedSequence to preserve dtype
# for empty sequences where type cannot be inferred from elements.
if op.name == "SequenceEmpty":
dtype = attributes.get("dtype") or onnx.TensorProto.FLOAT
return _TypedSequence(dtype, result)

return result

def eval_function(
self,
Expand Down Expand Up @@ -364,8 +376,15 @@ def compute_num_outputs(

def _onnxscript_to_numpy_value(v):
"""Converts an onnxscript encoding of an ONNX value into the numpy encoding used by runtimes."""
from onnxscript._internal.typed_sequence import (
_TypedSequence, # pylint: disable=import-outside-toplevel
)

if isinstance(v, tensor.Tensor):
return v.value
# Handle _TypedSequence BEFORE generic list check to preserve dtype info
if isinstance(v, _TypedSequence):
return _TypedSequence(v.onnx_dtype, [_onnxscript_to_numpy_value(x) for x in v])
if isinstance(v, list):
return [_onnxscript_to_numpy_value(x) for x in v]
if isinstance(v, tuple):
Expand Down
33 changes: 33 additions & 0 deletions onnxscript/_internal/typed_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import Any, Iterable


class _TypedSequence(list):
"""A list subclass that preserves ONNX dtype information.

ONNX sequences need to carry element type information, but Python lists
don't have this natively. This class wraps a list with dtype info, which is
especially important for empty sequences where type cannot be inferred.

Example::
seq = _TypedSequence(onnx.TensorProto.FLOAT16)
# seq is now an empty list that knows it should contain FLOAT16 tensors
"""

__slots__ = ("_onnx_dtype",)

def __init__(self, dtype: int, iterable: Iterable[Any] | None = None) -> None:
super().__init__(iterable if iterable is not None else ())
self._onnx_dtype = dtype

@property
def onnx_dtype(self) -> int:
"""Return the ONNX TensorProto data type for sequence elements."""
return self._onnx_dtype

def __repr__(self) -> str:
return f"_TypedSequence(dtype={self._onnx_dtype}, {list(self)})"
8 changes: 8 additions & 0 deletions onnxscript/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import onnx_ir as ir

from onnxscript import tensor
from onnxscript._internal.typed_sequence import _TypedSequence


def external_tensor(
Expand Down Expand Up @@ -72,6 +73,13 @@ def value_to_type_proto(val):
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.INT32, []) # noqa: TID251
if isinstance(val, (float, np.float32)):
return onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, []) # noqa: TID251

# Handle _TypedSequence BEFORE generic list check
if isinstance(val, _TypedSequence):
return onnx.helper.make_sequence_type_proto( # noqa: TID251
onnx.helper.make_tensor_type_proto(val.onnx_dtype, None) # noqa: TID251
)

if isinstance(val, list):
if len(val) > 0:
return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0])) # noqa: TID251
Expand Down
44 changes: 44 additions & 0 deletions tests/eager_mode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import unittest

import numpy as np
import onnx
import parameterized

import onnxscript.evaluator
Expand Down Expand Up @@ -51,6 +52,49 @@ def Concat(seq):
output2 = Concat([os_tensor, os_tensor])
self.assertIsInstance(output2, onnxscript.tensor.Tensor)

def test_sequence_empty_preserves_dtype(self):
"""Regression test for SequenceEmpty dtype parameter.

Verify that SequenceEmpty correctly preserves the dtype attribute
so that SequenceInsert doesn't fail with type mismatch errors.
"""

@script()
def test_float16(img_in):
seq = op.SequenceEmpty(dtype=onnx.TensorProto.FLOAT16)
seq = op.SequenceInsert(seq, img_in)
return seq

@script()
def test_double(img_in):
seq = op.SequenceEmpty(dtype=onnx.TensorProto.DOUBLE)
seq = op.SequenceInsert(seq, img_in)
return seq

@script()
def test_int64(img_in):
seq = op.SequenceEmpty(dtype=onnx.TensorProto.INT64)
seq = op.SequenceInsert(seq, img_in)
return seq

# Test FLOAT16
img_float16 = np.random.randn(2, 3).astype(np.float16)
res = test_float16(img_float16)
self.assertEqual(len(res), 1)
np.testing.assert_array_equal(res[0], img_float16)

# Test DOUBLE
img_double = np.random.randn(2, 3).astype(np.float64)
res = test_double(img_double)
self.assertEqual(len(res), 1)
np.testing.assert_array_equal(res[0], img_double)

# Test INT64
img_int64 = np.array([[1, 2], [3, 4]], dtype=np.int64)
res = test_int64(img_int64)
self.assertEqual(len(res), 1)
np.testing.assert_array_equal(res[0], img_int64)


@script()
def add_with_alpha(this, other, alpha: float = 1.0):
Expand Down