-
Notifications
You must be signed in to change notification settings - Fork 27
Add AMD ROCm (gfx942) support for the image→3D generation stack #72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| # EmbodiedGen on AMD ROCm (gfx942 / MI300X / MI308X), ROCm 6.4.3 + PyTorch 2.6. | ||
| # | ||
| # Builds the FULL image-to-3D generation stack on ROCm by swapping the CUDA-only | ||
| # libraries for verified ROCm equivalents (rocm-lib-compat / ZJLi2013/rocm3d): | ||
| # spconv-cu120 -> spconv_rocm | nvdiffrast -> nvdiffrast@rocm | gsplat -> amd_gsplat | ||
| # pytorch3d -> ROCm 6.4/py3.12 wheel | flash-attn -> FA2-Triton | numpy -> <2 | ||
| # kaolin (CUDA-only) -> sitecustomize stub (texture/mesh-IO stage only) | ||
| # Verified on AMD Instinct MI308X: SAM3D + TRELLIS pipelines import & initialize | ||
| # (spconv backend + flash_attn; SAM3D attention -> sdpa). See docker/amd_rocm/README.md. | ||
| # | ||
| # Build (from repo root, submodules checked out): | ||
| # docker build -f docker/amd_rocm/Dockerfile -t embodiedgen:rocm6.4.3 . | ||
| # Run img3d (GPT-free / no texture-bake smoke): | ||
| # docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video \ | ||
| # --shm-size 32g -e FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \ | ||
| # embodiedgen:rocm6.4.3 python -m embodied_gen.models.sam3d | ||
|
|
||
| FROM rocm/pytorch:rocm6.4.3_ubuntu24.04_py3.12_pytorch_release_2.6.0 | ||
|
|
||
| ENV DEBIAN_FRONTEND=noninteractive \ | ||
| PYTHONUNBUFFERED=1 \ | ||
| PIP_NO_CACHE_DIR=1 \ | ||
| PIP_ROOT_USER_ACTION=ignore \ | ||
| PYTORCH_ROCM_ARCH=gfx942 \ | ||
| GPU_ARCHS=gfx942 \ | ||
| FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \ | ||
| PYTHONPATH=/workspace/EmbodiedGen | ||
|
|
||
| WORKDIR /workspace/EmbodiedGen | ||
|
|
||
| # Source tree (incl. thirdparty/TRELLIS, thirdparty/sam3d submodules) is required | ||
| # for install_rocm.sh (cleans requirements.txt, builds extensions, installs stub). | ||
| COPY . /workspace/EmbodiedGen | ||
|
|
||
| # Install the ROCm generation stack. Compiles spconv_rocm / nvdiffrast / flash-attn | ||
| # from source (hipcc), so this layer is the slow one. | ||
| RUN bash docker/amd_rocm/install_rocm.sh | ||
|
|
||
| # Smoke: the full img3d import+init chain on ROCm (no model download / no GPT). | ||
| CMD ["python", "-c", "import embodied_gen.models.sam3d, thirdparty.TRELLIS.trellis.pipelines; print('EmbodiedGen ROCm generation image OK')"] | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| # EmbodiedGen on AMD ROCm (MI300X / MI308X) | ||
|
|
||
| Run the EmbodiedGen **image-to-3D generation** stack on AMD GPUs (gfx942) with | ||
| ROCm 6.4.3 + PyTorch 2.6, by swapping the CUDA-only libraries for verified ROCm | ||
| equivalents. Verified on **AMD Instinct MI308X**: the SAM3D and TRELLIS pipelines | ||
| import and initialize (spconv backend + flash-attn; SAM3D attention auto-selects | ||
| `sdpa`). | ||
|
|
||
| > Library swaps follow the `rocm-lib-compat` reference | ||
| > ([ZJLi2013/rocm3d](https://github.com/ZJLi2013/rocm3d)). The same TRELLIS-v1 | ||
| > stack is independently verified there via `VAST-AI/AniGen`. | ||
|
|
||
| ## TL;DR | ||
|
|
||
| ```bash | ||
| # from repo root, with submodules checked out: | ||
| git submodule update --init --recursive | ||
| docker build -f docker/amd_rocm/Dockerfile -t embodiedgen:rocm6.4.3 . | ||
|
|
||
| # import+init smoke (no download / no GPT): | ||
| docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video \ | ||
| --shm-size 32g -e FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \ | ||
| embodiedgen:rocm6.4.3 | ||
|
|
||
| # full GPT-free image->3D (downloads facebook/sam-3d-objects, ~15GB; saves splat.ply): | ||
| docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video \ | ||
| --shm-size 32g -e FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE \ | ||
| -v $PWD:/workspace/EmbodiedGen embodiedgen:rocm6.4.3 \ | ||
| python -m embodied_gen.models.sam3d | ||
| ``` | ||
|
|
||
| To run the swaps directly in a base container instead of building the image: | ||
|
|
||
| ```bash | ||
| docker run -it --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 32g \ | ||
| -v $PWD:/workspace/EmbodiedGen -w /workspace/EmbodiedGen \ | ||
| rocm/pytorch:rocm6.4.3_ubuntu24.04_py3.12_pytorch_release_2.6.0 \ | ||
| bash docker/amd_rocm/install_rocm.sh | ||
| ``` | ||
|
|
||
| ## CUDA -> ROCm dependency map | ||
|
|
||
| | Upstream (CUDA) | ROCm replacement | Status on MI308X | | ||
| |---|---|---| | ||
| | `spconv-cu120` | [`ZJLi2013/spconv_rocm`](https://github.com/ZJLi2013/spconv_rocm) (source) | ✅ import OK | | ||
| | `nvdiffrast` | [`ZJLi2013/nvdiffrast@rocm`](https://github.com/ZJLi2013/nvdiffrast) | ✅ import OK | | ||
| | `gsplat` | `amd_gsplat` (`pypi.amd.com/rocm-6.4.3`), import name `gsplat` | ✅ default GS backend | | ||
| | `pytorch3d` | ROCm 6.4 / py3.12 prebuilt wheel | ✅ import OK | | ||
| | `flash-attn` | FA2-Triton (`FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE` at install **and** runtime) | ✅ import OK | | ||
| | `xformers` | not needed — SAM3D attention auto-selects `sdpa` | ✅ skipped | | ||
| | `numpy` (base ships 2.x) | pin `numpy<2` (diffusers/transformers requirement) | ✅ | | ||
| | `kaolin` (no ROCm wheel; setup.py requires `nvcc`) | `sitecustomize` stub (`docker/amd_rocm/kaolin_stub.py`) | ⚠️ texture-stage only | | ||
| | `diff-gaussian-rasterization` | optional 'inria' GS backend (gsplat is default) | ⏸ optional | | ||
|
|
||
| ## The kaolin stub (`docker/amd_rocm/kaolin_stub.py`) | ||
|
|
||
| `kaolin` is CUDA-only and is imported at the top of `embodied_gen/data/utils.py`, | ||
| but is only **used** inside the texture-backprojection / mesh-IO stage | ||
| (`kal.io.*.import_mesh`, `kal.render.materials.PBRMaterial`, | ||
| `kaolin.render.camera.Camera`) and as type references in `thirdparty/sam3d`. | ||
| None of it is on the core geometry+gaussian generation path. The stub (installed | ||
| as `sitecustomize.py`) fabricates any `kaolin.*` module so every `import kaolin` | ||
| resolves; the texture stage raises a clear error if actually invoked. This mirrors | ||
| the proven `ZJLi2013/RealWonder` bypass (~85% pipeline usable on ROCm). | ||
|
|
||
| The upstream-friendly long-term fix is to make the kaolin imports in | ||
| `data/utils.py` lazy/optional so the stub is unnecessary. | ||
|
|
||
| ## Known gaps | ||
|
|
||
| - **Texture backprojection** (`backproject_v3` / `differentiable_render`) calls | ||
| real kaolin mesh-IO and is not available under the stub. Core image-to-3D | ||
| (segmentation -> SAM3D geometry + gaussian + mesh export) runs without it. | ||
| - **GPT quality-checkers / URDF semantics** (`img3d-cli`) need a GPT key; the | ||
| `python -m embodied_gen.models.sam3d` path skips them entirely. | ||
| - **`diff-gaussian-rasterization`** (mip-splatting / antialiasing fork) needs | ||
| `__trap`->`abort` and a `cooperative_groups/reduce.h` shim to build on ROCm; | ||
| it is optional because `gsplat` is the default gaussian backend. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| #!/bin/bash | ||
| # EmbodiedGen ROCm install (gfx942 / MI300X / MI308X), ROCm 6.4.3 + PyTorch 2.6. | ||
| # Swaps the CUDA-only generation stack for ROCm-compatible builds following the | ||
| # rocm-lib-compat reference (github.com/ZJLi2013/rocm3d). Intended to run inside | ||
| # rocm/pytorch:rocm6.4.3_ubuntu24.04_py3.12_pytorch_release_2.6.0 with the repo at | ||
| # /workspace/EmbodiedGen. Each step reports PASS/FAIL so one run yields a full | ||
| # ROCm-compat status map; the script exits non-zero at the end if any required | ||
| # step (or core import in the smoke check) failed, so a Docker build aborts. | ||
| set -uo pipefail | ||
|
|
||
| export PYTORCH_ROCM_ARCH=gfx942 | ||
| export GPU_ARCHS=gfx942 | ||
| export FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE | ||
| export PIP_ROOT_USER_ACTION=ignore | ||
| REPO=${REPO:-/workspace/EmbodiedGen} | ||
| cd "$REPO" | ||
|
|
||
| PASS=(); FAIL=() | ||
| step () { # step "name" cmd... | ||
| local name="$1"; shift | ||
| echo "==================== STEP: $name ====================" | ||
| if "$@"; then echo "[PASS] $name"; PASS+=("$name"); | ||
| else echo "[FAIL] $name"; FAIL+=("$name"); fi | ||
| } | ||
|
HochCC marked this conversation as resolved.
|
||
| pipi () { pip install --no-cache-dir "$@"; } | ||
|
|
||
| # --- 0. keep ROCm torch from base image (do NOT reinstall cu118 torch) --- | ||
| python -c "import torch;print('base torch',torch.__version__,'hip',torch.version.hip,'gpu',torch.cuda.is_available())" | ||
|
|
||
| # --- 1. requirements.txt minus CUDA-pinned libs (handled below or via base) --- | ||
| EXCLUDE='torch|torchvision|torchaudio|xformers|gsplat|flash.attn|flash-attn|triton|spconv|spconv-cu120|pytorch3d' | ||
| grep -vEi "^(${EXCLUDE})([<>=!~;[:space:]]|$)" requirements.txt > /tmp/req_clean.txt | ||
| echo "--- cleaned requirements (CUDA libs stripped) ---"; cat /tmp/req_clean.txt | ||
| step "requirements(clean)" pipi -r /tmp/req_clean.txt --use-deprecated=legacy-resolver | ||
| # numpy: EmbodiedGen's diffusers/transformers REQUIRE numpy<2. (The rocm-lib-compat | ||
| # "use docker numpy 2.x" guidance does NOT apply here; the base image ships numpy 2.x.) | ||
| step "numpy<2" pipi "numpy<2" | ||
| # NOTE: xformers is NOT required -- SAM3D attention auto-selects the `sdpa` backend on | ||
| # ROCm, and TRELLIS uses spconv+flash_attn. Skipping xformers avoids the torch 2.9.1 | ||
| # bump that would break the pytorch3d/gsplat ROCm 6.4 wheels. | ||
|
|
||
| # --- 2. ROCm replacements for the CUDA-only generation stack (rocm-lib-compat) --- | ||
| # All verified on gfx942 / MI300X, ROCm 6.4 (same stack as VAST-AI/AniGen, | ||
| # a TRELLIS-v1 image-to-3D repo, in the rocm3d supported-repo list). | ||
| # spconv (CUDA spconv-cu120 -> ZJLi2013/spconv_rocm) | ||
| step "spconv_rocm" bash -c ' | ||
| rm -rf /tmp/spconv_rocm && | ||
| git clone --depth 1 -b rocm https://github.com/ZJLi2013/spconv_rocm.git /tmp/spconv_rocm && | ||
| pip install --no-cache-dir -e /tmp/spconv_rocm' | ||
| # nvdiffrast (NVlabs -> ZJLi2013/nvdiffrast@rocm) | ||
| step "nvdiffrast_rocm" pipi "git+https://github.com/ZJLi2013/nvdiffrast.git@rocm" --no-build-isolation | ||
| # gsplat (-> amd_gsplat prebuilt; import name stays `gsplat`; default gaussian backend) | ||
| step "amd_gsplat" pipi amd_gsplat --extra-index-url=https://pypi.amd.com/rocm-6.4.3/simple/ | ||
| # pytorch3d (-> prebuilt ROCm 6.4 / py3.12 wheel) | ||
| step "pytorch3d_rocm" pipi https://github.com/ZJLi2013/pytorch3d/releases/download/rocm6.4-py3.12/pytorch3d-0.7.9-cp312-cp312-linux_x86_64.whl | ||
| # flash-attn (FA2 Triton on ROCm 6.4). NOTE: requires FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE | ||
| # at BOTH install and runtime; otherwise import falls back to the CUDA `flash_attn_2_cuda`. | ||
| step "flash_attn(triton)" bash -c 'FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE pip install --no-cache-dir flash-attn --no-build-isolation' | ||
|
|
||
| # --- 3. pure git deps from upstream install_basic.sh (CUDA-agnostic) --- | ||
| step "utils3d" pipi "utils3d@git+https://github.com/EasternJournalist/utils3d.git@9a4eb15" | ||
| step "clip" pipi "clip@git+https://github.com/openai/CLIP.git" | ||
| step "segment_anything" pipi "segment-anything@git+https://github.com/facebookresearch/segment-anything.git@dca509f" | ||
| step "kolors" pipi "kolors@git+https://github.com/HochCC/Kolors.git" | ||
| step "MoGe" pipi "MoGe@git+https://github.com/microsoft/MoGe.git@a8c3734" | ||
|
|
||
| # --- 4. OPTIONAL: diff-gaussian-rasterization ('inria' gaussian backend) --- | ||
| # img3d's default gaussian backend is gsplat (amd_gsplat, above), and both TRELLIS | ||
| # and SAM3D guard the diff_gaussian_rasterization import in try/except, so this is | ||
| # optional. The CUDA-clean ROCm source is graphdeco-inria built via expenses/ | ||
| # gaussian-splatting's ROCm branch (rocm3d supported-repo list). EmbodiedGen wires | ||
| # TRELLIS to the *mip-splatting* antialiasing fork, whose ROCm build additionally | ||
| # needs `__trap`->abort and a cooperative_groups/reduce.h shim (PR candidate). | ||
| # Left out of the default install; uncomment to add the non-AA 'inria' backend: | ||
| # touch /opt/rocm/include/device_launch_parameters.h | ||
| # step "diff_gaussian_rasterization" bash -c ' | ||
| # rm -rf /tmp/dgr && | ||
| # git clone https://github.com/graphdeco-inria/diff-gaussian-rasterization /tmp/dgr && | ||
| # cd /tmp/dgr && git submodule update --init --recursive && | ||
| # PYTORCH_ROCM_ARCH=gfx942 pip install --no-cache-dir . --no-build-isolation' | ||
|
|
||
| # --- 5. ROCm runtime shims, installed as sitecustomize (run at interpreter startup) --- | ||
| # (a) kaolin bypass: kaolin is CUDA-only (no ROCm wheel, setup.py hard-requires nvcc). | ||
| # Imported at module top of embodied_gen/data/utils.py but only *used* inside the | ||
| # texture-backprojection / mesh-IO stage (kal.io.*.import_mesh, render.materials, | ||
| # render.camera) + type refs in thirdparty/sam3d -- none on the core image->3D | ||
| # geometry+gaussian path. Stub pattern proven on ZJLi2013/RealWonder (same | ||
| # SAM-3D-Objects/kaolin dep). check_tensor-style validators return truthy. | ||
| # (b) spconv KRSC->Native weight bridge: SAM3D/TRELLIS checkpoints store sparse-conv | ||
| # weights in CUDA spconv's ImplicitGemm KRSC layout (5D [out,k,k,k,in]); spconv_rocm | ||
| # falls back to the Native algo (3D [Kvol,in,out], 2D when kvol==1), so load_state_dict | ||
| # mismatches. The shim converts on load. (Upstream fix belongs in spconv_rocm.) | ||
| SITE=$(python -c "import site;print(site.getsitepackages()[0])") | ||
| HERE="$(dirname "$0")" | ||
| if cp "$HERE/kaolin_stub.py" "$SITE/kaolin_stub.py" \ | ||
| && cp "$HERE/spconv_rocm_compat.py" "$SITE/spconv_rocm_compat.py" \ | ||
| && printf 'import kaolin_stub\nimport spconv_rocm_compat\n' > "$SITE/sitecustomize.py"; then | ||
| echo "[PASS] rocm-shims -> $SITE/sitecustomize.py (kaolin_stub + spconv_rocm_compat)"; PASS+=("rocm-shims") | ||
| else | ||
| echo "[FAIL] rocm-shims copy"; FAIL+=("rocm-shims") | ||
| fi | ||
|
|
||
| # --- 6. import smoke: what actually loads on ROCm (flash-attn needs the env var) --- | ||
| # Core imports failing is fatal (exit 1); optional ones only warn. | ||
| echo "==================== IMPORT SMOKE ====================" | ||
| if ! FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE python - <<'PY' | ||
| import importlib, sys | ||
| core = ["torch","spconv","nvdiffrast.torch","gsplat","pytorch3d","flash_attn","trimesh","diffusers"] | ||
| optional = ["diff_gaussian_rasterization","kaolin"] | ||
| def check(m): | ||
| try: | ||
| importlib.import_module(m); print(f"[import OK ] {m}"); return True | ||
| except Exception as e: | ||
| print(f"[import ERR] {m}: {type(e).__name__}: {str(e)[:160]}"); return False | ||
| print("-- core --"); core_ok = all([check(m) for m in core]) | ||
| print("-- optional --"); [check(m) for m in optional] | ||
| import torch | ||
| print("torch", torch.__version__, "hip", torch.version.hip, "gpu", torch.cuda.is_available()) | ||
| sys.exit(0 if core_ok else 1) | ||
| PY | ||
| then | ||
| echo "[FAIL] import-smoke(core)"; FAIL+=("import-smoke") | ||
| fi | ||
|
|
||
| echo "==================== SUMMARY ====================" | ||
| echo "PASS (${#PASS[@]}): ${PASS[*]}" | ||
| echo "FAIL (${#FAIL[@]}): ${FAIL[*]}" | ||
| echo "INSTALL_ROCM_DONE" | ||
|
|
||
| # Fail the Docker build if any required (non-optional) step failed. The two | ||
| # `optional` items (diff_gaussian_rasterization, kaolin) stay non-fatal by design. | ||
| if [ "${#FAIL[@]}" -ne 0 ]; then | ||
| echo "ERROR: required ROCm step(s) failed: ${FAIL[*]}" >&2 | ||
| exit 1 | ||
| fi | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| """ROCm kaolin bypass for EmbodiedGen (generalized from ZJLi2013/RealWonder). | ||
|
|
||
| kaolin is CUDA-only (no ROCm wheel; setup.py hard-requires nvcc). In EmbodiedGen | ||
| it is imported at module top of `embodied_gen/data/utils.py` and used only inside | ||
| the texture-backprojection / differentiable-render stage (`kal.io.*.import_mesh`, | ||
| `kal.render.materials.PBRMaterial`, `kaolin.render.camera.Camera`), plus type-level | ||
| references in thirdparty/sam3d. None of it is on the core image->3D geometry+gaussian | ||
| generation path (gsplat is the gaussian backend), so stubbing `kaolin` lets img3d-cli | ||
| run on ROCm. The texture-baking stage that actually calls these will surface a clear | ||
| error instead of crashing every import. | ||
|
|
||
| Activation (must run before any `import kaolin`): | ||
| - drop this file's directory on PYTHONPATH as `sitecustomize.py`, or | ||
| - `import kaolin_stub` as the very first import of the entrypoint. | ||
|
|
||
| This is the ROCm-unblock shim; the upstream-PR-appropriate fix is to make the kaolin | ||
| imports in `data/utils.py` lazy/optional (see docs/exp18.md). | ||
| """ | ||
| import importlib.abc | ||
| import importlib.machinery | ||
| import sys | ||
| import types | ||
|
|
||
|
|
||
| class _KaolinStubModule(types.ModuleType): | ||
| def __init__(self, name): | ||
| super().__init__(name) | ||
| self.__file__ = "<kaolin-stub>" | ||
| self.__path__ = [] | ||
| self.__spec__ = None | ||
|
|
||
| def __getattr__(self, name): | ||
| # Let dunder lookups (e.g. __file__, __wrapped__, __all__) behave normally. | ||
| if name.startswith("__") and name.endswith("__"): | ||
| raise AttributeError(name) | ||
| # Capitalized -> isinstance-safe stub class (Camera, PBRMaterial, ...). | ||
| if name and name[0].isupper(): | ||
| return type(name, (), {}) | ||
| # Lowercase -> a no-op callable that also behaves like a submodule. | ||
| stub = _KaolinCallableStub(f"{self.__name__}.{name}") | ||
| return stub | ||
|
|
||
|
|
||
| class _KaolinCallableStub(_KaolinStubModule): | ||
| def __call__(self, *args, **kwargs): | ||
| # Return a truthy no-op. The only kaolin calls on the core geometry/mesh path | ||
| # are validators like `kaolin.utils.testing.check_tensor(...)`, used inside | ||
| # `assert torch.is_tensor(x) and check_tensor(...)`, which need a truthy return. | ||
| # Data-returning kaolin calls (kal.io.*.import_mesh, render.*) live in the | ||
| # texture-backprojection stage and will fail fast downstream (documented gap). | ||
| return True | ||
|
|
||
|
|
||
| class _KaolinFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader): | ||
| def find_spec(self, name, path=None, target=None): | ||
| if name == "kaolin" or name.startswith("kaolin."): | ||
| return importlib.machinery.ModuleSpec(name, self) | ||
| return None | ||
|
|
||
| def create_module(self, spec): | ||
| return _KaolinStubModule(spec.name) | ||
|
|
||
| def exec_module(self, module): | ||
| pass | ||
|
|
||
|
|
||
| if "kaolin" not in sys.modules and not any( | ||
| isinstance(f, _KaolinFinder) for f in sys.meta_path | ||
| ): | ||
| sys.meta_path.insert(0, _KaolinFinder()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| """spconv KRSC->Native weight bridge for ROCm. | ||
|
|
||
| CUDA spconv 2.3.x defaults to the ImplicitGemm conv algo, whose `SparseConvolution` | ||
| weights are stored in KRSC layout: 5D `[out_channels, kd, kh, kw, in_channels]`. | ||
| ROCm `spconv_rocm` (2.3.8+rocm1) lacks the implicit-gemm kernels and falls back to | ||
| the Native algo, whose weights are 3D `[kernel_volume, in_channels, out_channels]`. | ||
|
|
||
| So checkpoints trained on CUDA (e.g. facebook/sam-3d-objects, SAM3D, TRELLIS) fail to | ||
| load on ROCm with errors like: | ||
| size mismatch ... copying a param with shape [128, 3, 3, 3, 128] | ||
| the shape in current model is [27, 128, 128] | ||
|
|
||
| This patches `torch.nn.Module.load_state_dict` to transparently convert any 5D KRSC | ||
| spconv weight into the 3D Native layout when (and only when) the destination model | ||
| parameter is the matching 3D shape. It is a no-op on CUDA / already-native weights. | ||
|
|
||
| Activation: import before loading any spconv checkpoint (e.g. via sitecustomize). | ||
| This is the ROCm-unblock shim; the upstream-appropriate fix belongs in spconv_rocm | ||
| (accept KRSC checkpoints under the Native algo). | ||
| """ | ||
| import torch | ||
|
|
||
| _orig_load_state_dict = torch.nn.Module.load_state_dict | ||
|
|
||
|
|
||
| def _krsc_to_native(w: torch.Tensor): | ||
| # KRSC [out, kd, kh, kw, in] -> Native [kd*kh*kw, in, out]; the Native algo squeezes | ||
| # kernel_volume==1 (1x1x1 conv) to 2D [in, out]. | ||
| out, kd, kh, kw, inc = w.shape | ||
| kvol = kd * kh * kw | ||
| native = w.permute(1, 2, 3, 4, 0).contiguous().reshape(kvol, inc, out) | ||
| return native, kvol, inc, out | ||
|
|
||
|
|
||
| def _patched_load_state_dict(self, state_dict, strict=True, *args, **kwargs): | ||
| try: | ||
| own = self.state_dict() | ||
| except Exception: | ||
| return _orig_load_state_dict(self, state_dict, strict=strict, *args, **kwargs) | ||
|
|
||
| converted = 0 | ||
| fixed = dict(state_dict) | ||
| for name, val in state_dict.items(): | ||
| tgt = own.get(name) | ||
| if tgt is None or not hasattr(val, "ndim") or val.ndim != 5: | ||
| continue | ||
| native, kvol, inc, out = _krsc_to_native(val) | ||
| if tgt.ndim == 3 and tuple(tgt.shape) == (kvol, inc, out): | ||
| fixed[name] = native | ||
| converted += 1 | ||
| elif tgt.ndim == 2 and kvol == 1 and tuple(tgt.shape) == (inc, out): | ||
| fixed[name] = native.reshape(inc, out) | ||
| converted += 1 | ||
| if converted: | ||
| print(f"[spconv-rocm-compat] converted {converted} KRSC->Native conv weights") | ||
| return _orig_load_state_dict(self, fixed, strict=strict, *args, **kwargs) | ||
|
|
||
|
|
||
| if getattr(torch.nn.Module.load_state_dict, "__name__", "") != "_patched_load_state_dict": | ||
| torch.nn.Module.load_state_dict = _patched_load_state_dict |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.