Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ for _ in range(100):
```bash
pip install crazyflow # CPU
pip install "crazyflow[gpu]" # GPU (Linux x86-64, CUDA 12)
pip install "crazyflow[mps]" # Apple Silicon (MPS via jax-mps)
```

Developer install with editable submodules ([pixi](https://pixi.sh/) required):
Expand All @@ -72,6 +73,8 @@ First-principles physics, one drone. CPU: AMD Ryzen 9 7950X. GPU: NVIDIA RTX 409

Full benchmarks including multi-drone scaling are in the [documentation](https://learnsyslab.github.io/crazyflow).

> **Apple Silicon:** Crazyflow also runs on the MPS backend via `jax-mps`. Pass `device='mps'` when creating a `Sim`, or set `JAX_MPS_ASYNC_DISPATCH=1` for async dispatch. Use `pip install "crazyflow[mps]"` or `pixi run pip install jax-mps` inside the pixi environment. Under pixi, jax-mps auto-registers as the default backend with CPU fallback, so `device='cpu'` tests also pass without extra configuration.

## Related packages

Crazyflow is built on two companion packages that can also be used independently:
Expand Down
6 changes: 4 additions & 2 deletions crazyflow/sim/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,10 @@ def build_mjx_model(self, spec: mujoco.MjSpec) -> tuple[Any, Any, Model, Data]:
"""Build the MuJoCo model and data structures for the simulation."""
mj_model = spec.compile()
mj_data = mujoco.MjData(mj_model)
mjx_model = mjx.put_model(mj_model, device=self.device)
mjx_data = mjx.put_data(mj_model, mj_data, device=self.device)
# Always use the JAX implementation directly to support
# non-standard backends such as jax-mps (Apple Silicon).
mjx_model = mjx.put_model(mj_model, device=self.device, impl="JAX")
mjx_data = mjx.put_data(mj_model, mj_data, device=self.device, impl="JAX")
mjx_data = jax.vmap(lambda _: mjx_data)(jnp.arange(self.n_worlds))
return mj_model, mj_data, mjx_model, mjx_data

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ requires-python = ">=3.11,<3.14" # MuJoCo has no wheels for python 14

[project.optional-dependencies]
gpu = ["jax[cuda12]"]
mps = ["jax-mps"]
benchmark = ["fire", "matplotlib", "pandas"]

[project.urls]
Expand Down
Loading