[PyTorch] Python DType enum #3039
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
| // when this runs, so the GIL is held and Python imports are legal. | ||
| static pybind11::object te_dtype_cls = | ||
| pybind11::module_::import("transformer_engine.pytorch.constants").attr("TE_DType"); | ||
| return te_dtype_cls(static_cast<int>(dtype)); |
Find a way to bind the C++ and Python DType through the pybind cast mechanism.
This is done for Python -> C++. For C++ -> Python, the explicit conversion can't be avoided.
| # pybind11 enum used as Quantizer.dtype | ||
| tex.DType, | ||
| # Python IntEnum used as Quantizer.dtype | ||
| TE_DType, |
Save/load backward compatibility should be preserved.
/te-ci
Greptile Summary
This PR replaces the pybind11
Confidence Score: 5/5
This PR is safe to merge. The migration is mechanically consistent across all 66 files, backward compatibility is preserved at every public boundary, and the cross-type equality patches restore the hash invariant correctly.
The core implementation is sound: DType.cast() normalizes inputs at every constructor site, MakePythonDType covers all C++→Python tensor-construction paths, and the patched __eq__/__ne__ on both enums maintain the hash-equality contract. The two findings are robustness concerns (the ignored convert flag in the type caster and the static pybind11::object array lifetime), neither of which affects correctness in the standard single-interpreter use case.
transformer_engine/common/util/dtype_pybind_conversion.h (convert flag) and transformer_engine/pytorch/csrc/common.cpp (static pybind11::object lifetime) are worth a second look before the library is used in embedded-Python or multi-interpreter scenarios.
Important Files Changed
Sequence Diagram
sequenceDiagram
participant PY as Python caller
participant DT as constants.DType (IntEnum)
participant TC as type_caster (dtype_pybind_conversion.h)
participant CPP as C++ DType
participant MPD as MakePythonDType (common.cpp)
participant TEX as tex.DType (pybind11)
Note over PY,TEX: Python → C++ (input path)
PY->>TC: pass DType / tex.DType / int
TC->>TC: PyLong_Check? → cast int value
TC->>CPP: transformer_engine::DType
Note over PY,TEX: C++ → Python (output path, quantizer construction)
CPP->>MPD: MakePythonDType(dtype)
MPD->>MPD: lookup in static cache (built once via module_::import)
MPD->>DT: return cached DType member
Note over PY,TEX: Equality (cross-type comparison)
PY->>DT: DType.__eq__(tex.DType)
DT->>DT: "int(self) == int(other)"
PY->>TEX: tex.DType.__eq__(DType) (patched in pybind.cpp)
TEX->>TEX: "int(self) == int(other)"
Reviews (10): Last reviewed commit: "Merge branch 'main' into te_dtype"
| # Fail fast at import time if a new enumerator is added | ||
| # on the C++ side without being mirrored above. | ||
| assert {m.name for m in DType} == set(tex.DType.__members__), ( | ||
| "DType is out of sync with transformer_engine_torch.DType; " | ||
| "add the new pybind enumerator to DType in constants.py." | ||
| ) |
Import-time sync check can be silently skipped
Python's -O (optimize) flag strips all assert statements, so this import-time guard that verifies DType is in sync with tex.DType will never run in optimized/production builds. A build where a new C++ enumerator was added without updating DType would import without error and produce silent mismatches downstream. Replace with an explicit if ... raise.
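One way to make the guard survive `python -O` is an explicit `if ... raise`. The sketch below is illustrative only: the enum values, the helper name `check_dtype_in_sync`, and the choice of `RuntimeError` are assumptions, and the extension enum is represented by a plain list of member names rather than the real `tex.DType`.

```python
import enum


class DType(enum.IntEnum):
    # Illustrative values only; the real enum takes them from the extension.
    kByte = 0
    kInt32 = 2
    kFloat32 = 4


def check_dtype_in_sync(python_enum, extension_member_names):
    """Raise if the Python enum and the extension enum disagree on member names.

    Unlike ``assert``, an explicit ``raise`` is not stripped by ``python -O``.
    """
    python_names = {member.name for member in python_enum}
    extension_names = set(extension_member_names)
    if python_names != extension_names:
        missing = sorted(extension_names - python_names)
        extra = sorted(python_names - extension_names)
        raise RuntimeError(
            "DType is out of sync with the extension enum: "
            f"missing in Python: {missing}, missing in extension: {extra}"
        )


# In the real module this would receive tex.DType.__members__ at import time.
check_dtype_in_sync(DType, ["kByte", "kInt32", "kFloat32"])
```

Reporting both directions of the mismatch may also help when the cause is a version skew between the Python package and the compiled library rather than a missed update in the source tree.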
| /*! @brief Register the Python -> C++ ``DType`` implicit conversion. | ||
| * Allows a Python object of type ``transformer_engine.pytorch.constants.DType`` | ||
| * to be passed wherever a pybind-bound ``transformer_engine::DType`` argument is expected. | ||
| * pybind-bound ``transformer_engine::DType`` argument is expected. | ||
| * Must be called after the pybind ``DType`` enum has been registered. | ||
| */ |
Duplicate sentence in the docstring: the line "pybind-bound transformer_engine::DType argument is expected." appears twice.
| /*! @brief Register the Python -> C++ ``DType`` implicit conversion. | |
| * Allows a Python object of type ``transformer_engine.pytorch.constants.DType`` | |
| * to be passed wherever a pybind-bound ``transformer_engine::DType`` argument is expected. | |
| * Must be called after the pybind ``DType`` enum has been registered. | |
| */ |
| class DType(enum.IntEnum): | ||
| """Python mirror of ``transformer_engine_torch.DType`` (pybind11 enum). | ||
| Members are constructed manually from the underlying pybind enum so | ||
| that this class is the single source of truth for dtype tags used | ||
| across ``transformer_engine.pytorch``. | ||
| """ | ||
|
|
||
| kByte = int(tex.DType.kByte) | ||
| kInt32 = int(tex.DType.kInt32) | ||
| kFloat32 = int(tex.DType.kFloat32) | ||
| kFloat16 = int(tex.DType.kFloat16) | ||
| kBFloat16 = int(tex.DType.kBFloat16) | ||
| kFloat8E4M3 = int(tex.DType.kFloat8E4M3) | ||
| kFloat8E5M2 = int(tex.DType.kFloat8E5M2) | ||
| kFloat4E2M1 = int(tex.DType.kFloat4E2M1) | ||
|
|
||
| @classmethod | ||
| def cast(cls, dtype: "DTypeSupported") -> "DType": | ||
| """Normalize any ``DTypeSupported`` value to the canonical ``DType`` ``IntEnum``. | ||
| ``DType`` is the canonical dtype tag used internally throughout | ||
| ``transformer_engine.pytorch``, and is what this function always outputs. | ||
| The pybind ``transformer_engine_torch.DType`` enum is an additional type | ||
| accepted as input (for backward compatibility), which this function maps | ||
| to the matching ``DType`` member so stored attributes are always ``DType``. | ||
| """ | ||
| if isinstance(dtype, cls): | ||
| return dtype | ||
| return cls(int(dtype)) |
Equality comparison between constants.DType and tex.DType silently returns False
tex.DType is a pybind11 enum without .arithmetic(), so its __eq__ only compares with the same C-extension type. constants.DType is an IntEnum (a Python int subclass), so int.__eq__ is used on the left side — but CPython's int.__eq__ returns NotImplemented for non-PyLong objects, and pybind11's __eq__ also returns NotImplemented for a non-tex.DType right-hand side. Python falls back to identity comparison, yielding False. Existing user code like if quantizer.dtype == tex.DType.kFloat8E4M3: now silently evaluates to False even though the types are equivalent. The PR's documented backward-compat guarantee covers constructors and checkpoints but not equality comparisons, leaving this as an undocumented silent break.
Fixed this now. We have equality comparators defined for both DType and tex.DType to handle this
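A minimal sketch of cross-type equality by integer value, under stated assumptions: `ExtDType` is a hypothetical pure-Python stand-in for the extension enum (not the real `tex.DType`), the member values are illustrative, and the actual patch in the PR may be written differently. The point it illustrates is that once two types compare equal, they must also hash equal, so `__hash__` is defined alongside `__eq__` on both sides.

```python
import enum


class ExtDType:
    """Hypothetical stand-in for an extension enum; not the real tex.DType."""

    def __init__(self, value):
        self._value = value

    def __int__(self):
        return self._value

    def __eq__(self, other):
        # Compare by integer value against either enum type.
        if isinstance(other, (ExtDType, DType)):
            return self._value == int(other)
        return NotImplemented

    def __hash__(self):
        return hash(self._value)


class DType(enum.IntEnum):
    kFloat8E4M3 = 7  # illustrative values
    kFloat8E5M2 = 8

    def __eq__(self, other):
        if isinstance(other, ExtDType):
            return int(self) == int(other)
        return int.__eq__(self, other)

    def __ne__(self, other):
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    def __hash__(self):
        # Defining __eq__ disables the inherited hash, so restore it explicitly.
        return hash(int(self))
```

With both sides hashing by integer value, a dict keyed by one type can be looked up with the other, which is the invariant the review summary refers to as the hash-equality contract.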
|
|
||
| import transformer_engine.pytorch as te | ||
| import transformer_engine_torch as tex | ||
| from transformer_engine.pytorch import constants |
Considering how frequent it is, shouldn't we just expose the DType in the top module (transformer_engine.pytorch)?
| return cached_dtype_object; | ||
| } | ||
|
|
||
| /*! @brief Register the Python -> C++ ``DType`` implicit conversion. |
Not sure if that is accurate to be honest. My understanding of this function is that it converts the Python object of constants.DType to tex.DType, and only then the real conversion to the C++ type happens - this would actually increase the overhead I think. I believe the right approach is to use the custom type_caster from pybind to get the DType from either type.
| return transformer_engine::DType::kFloat8E5M2; | ||
| } | ||
|
|
||
| pybind11::object MakePythonDType(transformer_engine::DType dtype) { |
I am slightly worried about the thread safety of this function when there is no GIL.
Made it thread-safe now. The entire array is initialized once as a static const.
| Float8BlockQuantizer, | ||
| # pybind11 enum used as Quantizer.dtype | ||
| # Python IntEnum used as Quantizer.dtype. | ||
| constants.DType, |
Do we still need to add it here even if it is a regular Python object?
Yes. Every Python object that is used as an argument to the checkpoint-loading function needs its class allow-listed.
| # tex.DType is the pybind enum kept for backward compatibility. | ||
| # in the constructors for QuantizedTensors and Quantizers. | ||
| DTypeSupported = Union[DType, tex.DType] |
Hmmm, why can't you just use this union type directly in the cast function declaration rather than having this indirection?
This is reused in multiple places including Quantizer and QuantizedTensor constructors. So I defined it once here.
Now using Union directly
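For reference, the normalization pattern under discussion can be sketched with the `Union` written inline in the signature. This is a hedged illustration, not the code in the PR: `LegacyDType` is a hypothetical stand-in for the pybind enum, and the member values are made up.

```python
import enum
from typing import Union


class LegacyDType:
    """Hypothetical stand-in for the pybind enum accepted for backward compatibility."""

    def __init__(self, value: int):
        self._value = value

    def __int__(self) -> int:
        return self._value


class DType(enum.IntEnum):
    kByte = 0  # illustrative values
    kFloat32 = 4

    @classmethod
    def cast(cls, dtype: Union["DType", LegacyDType]) -> "DType":
        """Return the canonical member for either accepted input type."""
        if isinstance(dtype, cls):
            return dtype
        # Enum lookup by value; raises ValueError for an unknown value.
        return cls(int(dtype))
```

Calling `DType.cast` at every constructor boundary means stored attributes are always the canonical type, and an unknown value fails loudly with `ValueError` instead of being stored as-is.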
| # ``transformer_engine.h``). Use the bracket syntax ``TE_DType[torch_dtype]`` | ||
| # to resolve a ``torch.dtype`` to its matching ``DType`` member. | ||
| # Used for passing dtypes into cuda extension. | ||
| TE_DType = { |
| # on the C++ side without being mirrored above. | ||
| assert {m.name for m in DType} == set(tex.DType.__members__), ( | ||
| "DType is out of sync with transformer_engine_torch.DType; " | ||
| "add the new pybind enumerator to DType in constants.py." |
This is a good comment for TE devs, but if someone loads a newer or older C++ library version of TE with different dtypes, this error may be confusing.
I have improved the message now.
| } | ||
| return members; | ||
| }(); | ||
| return cache[idx]; |
MakePythonDType silently returns a null object on cache miss
cache[idx] is a default-constructed pybind11::object (null internal pointer) for any DType index not covered by the Python enum iteration — which would happen if the sync assertion in constants.py is stripped by Python's -O flag and a new C++ enum value is added without a matching DType member. Returning that null object propagates to callers such as kwargs["fp8_dtype"] = MakePythonDType(dtype), causing a null-pointer dereference or an opaque crash when the Python constructor receives the argument. The equivalent cache in dtype_pybind_conversion.h already guards against this with an explicit null check (if (cached == nullptr) { return handle(); }); the same guard belongs here.
/te-ci pytorch amd64 arm64
Description
Replace the pybind tex.DType with a canonical Python DType IntEnum throughout transformer_engine.pytorch. tex.DType is still accepted for backward compatibility, and the C++ -> Python and Python -> C++ DType object conversions are cached at the pybind boundaries to reduce CPU overhead.