Skip to content

[PyTorch] Python DType enum#3039

Open
vthumbe1503 wants to merge 34 commits into
NVIDIA:mainfrom
vthumbe1503:te_dtype
Open

[PyTorch] Python DType enum#3039
vthumbe1503 wants to merge 34 commits into
NVIDIA:mainfrom
vthumbe1503:te_dtype

Conversation

@vthumbe1503
Copy link
Copy Markdown
Collaborator

@vthumbe1503 vthumbe1503 commented May 22, 2026

Description

Replace the pybind tex.DType with a canonical Python DType IntEnum throughout transformer_engine.pytorch. For backward compat, cpp->python and python->cpp DType object conversions are cached at the pybind boundaries to reduce CPU overheads.

Motivation

  • CPU overheads: tex.DType is a pybind enum, so every access/compare/convert in Python crosses into C-extension code.
  • torch.compile: tex.DType won't work with torch.compile — TorchDynamo doesn't understand pybind enums, so it graph-breaks (or fails to trace) when one flows through a compiled region.
  • Checkpointing: tex.DType lives in tensor/quantizer state and lands in checkpoints; pickling a pybind enum is fragile and awkward to allow-list vs. a stdlib python enum.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Canonical DType: Add a pure-Python DType(IntEnum) in constants.py as the single source of truth. Its members are defined from the C++ enum values, and an import-time assert verifies it stays in sync with the pybinded enum tex.DType
  • Migration: Repoint TE_DType maps and move pytorch modules, examples, benchmarks, and tests off raw tex.DType onto constants.DType.
  • Backward compatibility: Add DTypeSupported = Union[DType, tex.DType]; tex.DType is still accepted at constructor boundaries and stays allow-listed for loading old checkpoints.
  • Python → C++: Register a cached pybind implicit conversion (dtype_pybind_conversion.h) so a constants.DType is auto-accepted wherever a C++ tex.DType is expected.
  • C++ → Python: Add cached MakePythonDType (csrc/common.*) and use it at quantizer/quantizedtensor construction.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 2 commits May 22, 2026 21:50
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title initial prototype TE_DType in python May 22, 2026
// when this runs, so the GIL is held and Python imports are legal.
static pybind11::object te_dtype_cls =
pybind11::module_::import("transformer_engine.pytorch.constants").attr("TE_DType");
return te_dtype_cls(static_cast<int>(dtype));
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Find a way to bind C++ and python Dtype through pybind cast mechanism

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is done for Python. -> C++

For C++ to Python. --> Cant avoid this.

Comment thread transformer_engine/pytorch/__init__.py Outdated
# pybind11 enum used as Quantizer.dtype
tex.DType,
# Python IntEnum used as Quantizer.dtype
TE_DType,
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

save/load backward compatibilty should be there

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

vthumbe1503 and others added 14 commits May 31, 2026 19:45
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title TE_DType in python [PyTorch] Python DType enum Jun 1, 2026
vthumbe1503 and others added 2 commits June 1, 2026 08:27
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as ready for review June 1, 2026 08:32
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 1, 2026

Greptile Summary

This PR replaces the pybind11 tex.DType enum with a pure-Python DType(IntEnum) defined in constants.py as the canonical dtype tag throughout transformer_engine.pytorch. The migration adds cached C++↔Python conversion helpers (MakePythonDType, dtype_pybind_conversion.h) and patches __eq__/__ne__ on both enum types to maintain cross-type equality for backward compatibility.

  • New DType IntEnum: defined in constants.py mirroring all tex.DType values; import-time assert keeps the two in sync; DType.cast() normalizes tex.DTypeDType at every constructor boundary.
  • C++ → Python path: MakePythonDType() in common.cpp builds a per-value static cache of constants.DType members and replaces all py::cast(this->dtype) calls in quantizer/tensor creation.
  • Python → C++ path: a custom type_caster<transformer_engine::DType> in dtype_pybind_conversion.h transparently accepts constants.DType, plain int, or tex.DType at every pybind11 boundary.

Confidence Score: 5/5

This PR is safe to merge. The migration is mechanically consistent across all 66 files, backward compatibility is preserved at every public boundary, and the cross-type equality patches restore the hash invariant correctly.

The core implementation is sound: DType.cast() normalizes inputs at every constructor site, MakePythonDType covers all C++→Python tensor-construction paths, and the patched eq/ne on both enums maintain the hash-equality contract. The two findings are robustness concerns — the ignored convert flag in the type caster and the static pybind11::object array lifetime — neither of which affects correctness in the standard single-interpreter use case.

transformer_engine/common/util/dtype_pybind_conversion.h (convert flag) and transformer_engine/pytorch/csrc/common.cpp (static pybind11::object lifetime) are worth a second look before the library is used in embedded-Python or multi-interpreter scenarios.

Important Files Changed

Filename Overview
transformer_engine/pytorch/constants.py Introduces the canonical DType(IntEnum) with cross-type eq/ne/hash, a cast() classmethod, and an import-time sync assert; TE_DType_To_Torch is cleanly derived by reversing TE_DType.
transformer_engine/common/util/dtype_pybind_conversion.h New custom type caster for Python→C++ DType conversion; correctly handles IntEnum/int/tex.DType inputs, but the PyLong_Check path ignores the convert flag, deviating from pybind11 strict-mode convention.
transformer_engine/pytorch/csrc/common.cpp Adds MakePythonDType() with a C++11 magic-static cache of constants.DType members; includes the NVTE_CHECK guard on null members, but the static pybind11::object array carries a teardown-crash risk in embedded-Python scenarios.
transformer_engine/pytorch/csrc/extensions/pybind.cpp Patches eq/ne on tex.DType to compare by integer value against both tex.DType and Python int/IntEnum, enabling cross-type equality without breaking hash invariants.
transformer_engine/pytorch/csrc/quantizer.cpp All py::cast(this->dtype) calls replaced with MakePythonDType(this->dtype) at every tensor construction and convert_and_update_tensor site — coverage appears complete.
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py Accepts Union[DType, tex.DType] in new and normalizes immediately via DType.cast(fp8_dtype), ensuring even old-checkpoint deserialization paths store canonical DType.
transformer_engine/pytorch/init.py Exports DType at the package level and adds it to the pickle allow-list alongside tex.DType for backward-compatible checkpoint loading.
transformer_engine/pytorch/quantization.py Recipe state dataclass dtype fields migrated from tex.DType to DType; get_fp8_te_dtype / get_fp4_te_dtype / get_fp8_max return types updated consistently.

Sequence Diagram

sequenceDiagram
    participant PY as Python caller
    participant DT as constants.DType (IntEnum)
    participant TC as type_caster (dtype_pybind_conversion.h)
    participant CPP as C++ DType
    participant MPD as MakePythonDType (common.cpp)
    participant TEX as tex.DType (pybind11)

    Note over PY,TEX: Python → C++ (input path)
    PY->>TC: pass DType / tex.DType / int
    TC->>TC: PyLong_Check? → cast int value
    TC->>CPP: transformer_engine::DType

    Note over PY,TEX: C++ → Python (output path, quantizer construction)
    CPP->>MPD: MakePythonDType(dtype)
    MPD->>MPD: lookup in static cache (built once via module_::import)
    MPD->>DT: return cached DType member

    Note over PY,TEX: Equality (cross-type comparison)
    PY->>DT: DType.__eq__(tex.DType)
    DT->>DT: "int(self) == int(other)"
    PY->>TEX: tex.DType.__eq__(DType) (patched in pybind.cpp)
    TEX->>TEX: "int(self) == int(other)"
Loading

Reviews (10): Last reviewed commit: "Merge branch 'main' into te_dtype" | Re-trigger Greptile

Comment on lines +44 to +49
# Fail fast at import time if a new enumerator is added
# on the C++ side without being mirrored above.
assert {m.name for m in DType} == set(tex.DType.__members__), (
"DType is out of sync with transformer_engine_torch.DType; "
"add the new pybind enumerator to DType in constants.py."
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Import-time sync check can be silently skipped

Python's -O (optimize) flag strips all assert statements, so this import-time guard that verifies DType is in sync with tex.DType will never run in optimized/production builds. A build where a new C++ enumerator was added without updating DType would import without error and produce silent mismatches downstream. Replace with an explicit if ... raise.

Comment on lines +81 to +86
/*! @brief Register the Python -> C++ ``DType`` implicit conversion.
* Allows a Python object of type ``transformer_engine.pytorch.constants.DType``
* to be passed wherever a pybind-bound ``transformer_engine::DType`` argument is expected.
* pybind-bound ``transformer_engine::DType`` argument is expected.
* Must be called after the pybind ``DType`` enum has been registered.
*/
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Duplicate sentence in the docstring — the line "pybind-bound transformer_engine::DType argument is expected." appears twice.

Suggested change
/*! @brief Register the Python -> C++ ``DType`` implicit conversion.
* Allows a Python object of type ``transformer_engine.pytorch.constants.DType``
* to be passed wherever a pybind-bound ``transformer_engine::DType`` argument is expected.
* pybind-bound ``transformer_engine::DType`` argument is expected.
* Must be called after the pybind ``DType`` enum has been registered.
*/
/*! @brief Register the Python -> C++ ``DType`` implicit conversion.
* Allows a Python object of type ``transformer_engine.pytorch.constants.DType``
* to be passed wherever a pybind-bound ``transformer_engine::DType`` argument is expected.
* Must be called after the pybind ``DType`` enum has been registered.
*/

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +14 to +41
class DType(enum.IntEnum):
"""Python mirror of ``transformer_engine_torch.DType`` (pybind11 enum).
Members are constructed manually from the underlying pybind enum so
that this class is the single source of truth for dtype tags used
across ``transformer_engine.pytorch``.
"""

kByte = int(tex.DType.kByte)
kInt32 = int(tex.DType.kInt32)
kFloat32 = int(tex.DType.kFloat32)
kFloat16 = int(tex.DType.kFloat16)
kBFloat16 = int(tex.DType.kBFloat16)
kFloat8E4M3 = int(tex.DType.kFloat8E4M3)
kFloat8E5M2 = int(tex.DType.kFloat8E5M2)
kFloat4E2M1 = int(tex.DType.kFloat4E2M1)

@classmethod
def cast(cls, dtype: "DTypeSupported") -> "DType":
"""Normalize any ``DTypeSupported`` value to the canonical ``DType`` ``IntEnum``.
``DType`` is the canonical dtype tag used internally throughout
``transformer_engine.pytorch``, and is what this function always outputs.
The pybind ``transformer_engine_torch.DType`` enum is an additional type
accepted as input (for backward compatibility), which this function maps
to the matching ``DType`` member so stored attributes are always ``DType``.
"""
if isinstance(dtype, cls):
return dtype
return cls(int(dtype))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Equality comparison between constants.DType and tex.DType silently returns False

tex.DType is a pybind11 enum without .arithmetic(), so its __eq__ only compares with the same C-extension type. constants.DType is an IntEnum (a Python int subclass), so int.__eq__ is used on the left side — but CPython's int.__eq__ returns NotImplemented for non-PyLong objects, and pybind11's __eq__ also returns NotImplemented for a non-tex.DType right-hand side. Python falls back to identity comparison, yielding False. Existing user code like if quantizer.dtype == tex.DType.kFloat8E4M3: now silently evaluates to False even though the types are equivalent. The PR's documented backward-compat guarantee covers constructors and checkpoints but not equality comparisons, leaving this as an undocumented silent break.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed this now. We have equality comparators defined for both DType and tex.DType to handle this

@vthumbe1503 vthumbe1503 requested a review from ptrendx June 1, 2026 18:13
Comment thread benchmarks/benchmark_rht_cast.py Outdated

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import constants
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Considering how frequent it is, shouldn't we just expose the DType in the top module (transformer_engine.pytorch)?

return cached_dtype_object;
}

/*! @brief Register the Python -> C++ ``DType`` implicit conversion.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if that is accurate to be honest. My understanding of this function is that it converts the Python object of constants.DType to tex.DType, and only then the real conversion to the C++ type happens - this would actually increase the overhead I think. I believe the right approach is to use the custom type_caster from pybind to get the DType from either type.

Comment thread transformer_engine/pytorch/attention/dot_product_attention/utils.py Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/cast.cpp Outdated
return transformer_engine::DType::kFloat8E5M2;
}

pybind11::object MakePythonDType(transformer_engine::DType dtype) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am slightly worried about the thread safety of this function when there is no GIL.

Copy link
Copy Markdown
Collaborator Author

@vthumbe1503 vthumbe1503 Jun 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made it thread safe now. Entire array is initialized once as a static const

Comment thread transformer_engine/pytorch/csrc/common.h Outdated
Comment thread transformer_engine/pytorch/__init__.py Outdated
Float8BlockQuantizer,
# pybind11 enum used as Quantizer.dtype
# Python IntEnum used as Quantizer.dtype.
constants.DType,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need to add it here even if it is a regular Python object?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes for every python object that is used as an argument for checkpoint loading function needs its class allow listed

Comment thread transformer_engine/pytorch/constants.py Outdated
Comment thread transformer_engine/pytorch/constants.py Outdated
Comment on lines +52 to +54
# tex.DType is the pybind enum kept for backward compatibility.
# in the constructors for QuantizedTensors and Quantizers.
DTypeSupported = Union[DType, tex.DType]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, why can't you just use this union type directly in the cast function declaration rather than having this indirection?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is reused in multiple places including Quantizer and QuantizedTensor constructors. So I defined it once here.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now using Union directly

# ``transformer_engine.h``). Use the bracket syntax ``TE_DType[torch_dtype]``
# to resolve a ``torch.dtype`` to its matching ``DType`` member.
# Used for passing dtypes into cuda extension.
TE_DType = {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we sometimes need int64

Comment thread transformer_engine/pytorch/constants.py Outdated
vthumbe1503 and others added 9 commits June 2, 2026 06:01
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Comment thread transformer_engine/pytorch/tensor/float8_tensor.py Outdated
Comment thread transformer_engine/pytorch/constants.py Outdated
# on the C++ side without being mirrored above.
assert {m.name for m in DType} == set(tex.DType.__members__), (
"DType is out of sync with transformer_engine_torch.DType; "
"add the new pybind enumerator to DType in constants.py."
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is good comment for TE devs, but if someone loads newer/older c++ library version of TE with different dtypes this error may be confusing.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have made the comment better now.

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
}
return members;
}();
return cache[idx];
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 MakePythonDType silently returns a null object on cache miss

cache[idx] is a default-constructed pybind11::object (null internal pointer) for any DType index not covered by the Python enum iteration — which would happen if the sync assertion in constants.py is stripped by Python's -O flag and a new C++ enum value is added without a matching DType member. Returning that null object propagates to callers such as kwargs["fp8_dtype"] = MakePythonDType(dtype), causing a null-pointer dereference or an opaque crash when the Python constructor receives the argument. The equivalent cache in dtype_pybind_conversion.h already guards against this with an explicit null check (if (cached == nullptr) { return handle(); }); the same guard belongs here.

vthumbe1503 and others added 5 commits June 2, 2026 19:21
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch amd64 arm64

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants