diff --git a/README.md b/README.md index 236a415..26f831c 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ for _ in range(100): ```bash pip install crazyflow # CPU pip install "crazyflow[gpu]" # GPU (Linux x86-64, CUDA 12) +pip install "crazyflow[mps]" # Apple Silicon (MPS via jax-mps) ``` Developer install with editable submodules ([pixi](https://pixi.sh/) required): @@ -72,6 +73,8 @@ First-principles physics, one drone. CPU: AMD Ryzen 9 7950X. GPU: NVIDIA RTX 409 Full benchmarks including multi-drone scaling are in the [documentation](https://learnsyslab.github.io/crazyflow). +> **Apple Silicon:** Crazyflow also runs on the MPS backend via `jax-mps`. Pass `device='mps'` when creating a `Sim`, or set `JAX_MPS_ASYNC_DISPATCH=1` for async dispatch. Use `pip install "crazyflow[mps]"` or `pixi run pip install jax-mps` inside the pixi environment. Under pixi, jax-mps auto-registers as the default backend with CPU fallback, so `device='cpu'` tests also pass without extra configuration. + ## Related packages Crazyflow is built on two companion packages that can also be used independently: diff --git a/crazyflow/sim/sim.py b/crazyflow/sim/sim.py index 644f62b..0da94ea 100644 --- a/crazyflow/sim/sim.py +++ b/crazyflow/sim/sim.py @@ -235,8 +235,10 @@ def build_mjx_model(self, spec: mujoco.MjSpec) -> tuple[Any, Any, Model, Data]: """Build the MuJoCo model and data structures for the simulation.""" mj_model = spec.compile() mj_data = mujoco.MjData(mj_model) - mjx_model = mjx.put_model(mj_model, device=self.device) - mjx_data = mjx.put_data(mj_model, mj_data, device=self.device) + # Always use the JAX implementation directly to support + # non-standard backends such as jax-mps (Apple Silicon). + mjx_model = mjx.put_model(mj_model, device=self.device, impl="JAX") + mjx_data = mjx.put_data(mj_model, mj_data, device=self.device, impl="JAX") mjx_data = jax.vmap(lambda _: mjx_data)(jnp.arange(self.n_worlds)) return mj_model, mj_data, mjx_model, mjx_data diff --git a/pyproject.toml b/pyproject.toml index 8a54bdf..a6658a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ requires-python = ">=3.11,<3.14" # MuJoCo has no wheels for python 14 [project.optional-dependencies] gpu = ["jax[cuda12]"] +mps = ["jax-mps"] benchmark = ["fire", "matplotlib", "pandas"] [project.urls]