From e1e646567b1dc9a5e79d65548d7338b88b50bd20 Mon Sep 17 00:00:00 2001 From: vlordier Date: Tue, 16 Jun 2026 21:46:04 +0200 Subject: [PATCH] Add MPS (Apple Silicon) support via jax-mps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Always pass impl='JAX' to mjx.put_model/put_data — this works with any JAX backend (cpu, gpu, tpu, mps) and avoids MJX's platform whitelist check that rejects unknown platforms. Add mps optional dependency to pyproject.toml. Document MPS usage in README. --- README.md | 3 +++ crazyflow/sim/sim.py | 6 ++++-- pyproject.toml | 1 + 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 236a415..26f831c 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ for _ in range(100): ```bash pip install crazyflow # CPU pip install "crazyflow[gpu]" # GPU (Linux x86-64, CUDA 12) +pip install "crazyflow[mps]" # Apple Silicon (MPS via jax-mps) ``` Developer install with editable submodules ([pixi](https://pixi.sh/) required): @@ -72,6 +73,8 @@ First-principles physics, one drone. CPU: AMD Ryzen 9 7950X. GPU: NVIDIA RTX 409 Full benchmarks including multi-drone scaling are in the [documentation](https://learnsyslab.github.io/crazyflow). +> **Apple Silicon:** Crazyflow also runs on the MPS backend via `jax-mps`. Pass `device='mps'` when creating a `Sim`, or set `JAX_MPS_ASYNC_DISPATCH=1` for async dispatch. Use `pip install "crazyflow[mps]"` or `pixi run pip install jax-mps` inside the pixi environment. Under pixi, jax-mps auto-registers as the default backend with CPU fallback, so `device='cpu'` tests also pass without extra configuration. + ## Related packages Crazyflow is built on two companion packages that can also be used independently: diff --git a/crazyflow/sim/sim.py b/crazyflow/sim/sim.py index 644f62b..0da94ea 100644 --- a/crazyflow/sim/sim.py +++ b/crazyflow/sim/sim.py @@ -235,8 +235,10 @@ def build_mjx_model(self, spec: mujoco.MjSpec) -> tuple[Any, Any, Model, Data]: """Build the MuJoCo model and data structures for the simulation.""" mj_model = spec.compile() mj_data = mujoco.MjData(mj_model) - mjx_model = mjx.put_model(mj_model, device=self.device) - mjx_data = mjx.put_data(mj_model, mj_data, device=self.device) + # Always use the JAX implementation directly to support + # non-standard backends such as jax-mps (Apple Silicon). + mjx_model = mjx.put_model(mj_model, device=self.device, impl="JAX") + mjx_data = mjx.put_data(mj_model, mj_data, device=self.device, impl="JAX") mjx_data = jax.vmap(lambda _: mjx_data)(jnp.arange(self.n_worlds)) return mj_model, mj_data, mjx_model, mjx_data diff --git a/pyproject.toml b/pyproject.toml index 8a54bdf..a6658a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ requires-python = ">=3.11,<3.14" # MuJoCo has no wheels for python 14 [project.optional-dependencies] gpu = ["jax[cuda12]"] +mps = ["jax-mps"] benchmark = ["fire", "matplotlib", "pandas"] [project.urls]