Skip to content

Add MPS (Apple Silicon) support via jax-mps#69

Open
vlordier wants to merge 1 commit into
learnsyslab:mainfrom
vlordier:feat/mps-support
Open

Add MPS (Apple Silicon) support via jax-mps#69
vlordier wants to merge 1 commit into
learnsyslab:mainfrom
vlordier:feat/mps-support

Conversation

@vlordier

Copy link
Copy Markdown

Summary

Enables crazyflow to run on Apple Silicon Macs using the jax-mps plugin, which exposes Apple's Metal Performance Shaders as a JAX backend.

Changes

  • crazyflow/sim/sim.py — Pass impl='JAX' to mjx.put_model/ mjx.put_data when device.platform == 'mps'. MJX's _resolve_impl only recognizes cpu, gpu, and tpu platforms; forcing the JAX impl bypasses this check while still placing data on the MPS device.
  • pyproject.toml — Add mps = ["jax-mps"] optional dependency.
  • README.md — Document MPS pip install and usage.

Testing

All 331 tests (262 unit + 69 integration) pass on an M-series MacBook Pro with JAX_MPS_ASYNC_DISPATCH=1. The device='cpu' tests also pass since jax-mps auto-registers alongside the CPU backend.

Performance (first-principles physics, 1 drone)

n_worlds Steps/s (MPS)
64 204
1024 38,304

@vlordier vlordier requested a review from amacati as a code owner June 16, 2026 19:37
Copilot AI review requested due to automatic review settings June 16, 2026 19:37

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds Apple Silicon (MPS) support by introducing an mps extra dependency, documenting installation/usage, and tweaking MJX model/data creation to work with jax-mps.

Changes:

  • Add mps optional dependency (jax-mps) in pyproject.toml.
  • Update Sim.build_mjx_model to force the JAX implementation when running on MPS.
  • Document MPS installation and usage in the README.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File Description
pyproject.toml Adds an mps extra to install jax-mps.
crazyflow/sim/sim.py Forces MJX to use the JAX impl for MPS devices.
README.md Documents crazyflow[mps] installation and MPS runtime notes.

Comment thread crazyflow/sim/sim.py Outdated
Comment on lines +238 to +241
# MJX does not recognize 'mps' as a platform; force JAX impl to work with jax-mps.
impl = "JAX" if self.device.platform == "mps" else None
mjx_model = mjx.put_model(mj_model, device=self.device, impl=impl)
mjx_data = mjx.put_data(mj_model, mj_data, device=self.device, impl=impl)
Comment thread pyproject.toml

[project.optional-dependencies]
gpu = ["jax[cuda12]"]
mps = ["jax-mps"]
Comment thread crazyflow/sim/sim.py Outdated
Comment on lines +239 to +241
impl = "JAX" if self.device.platform == "mps" else None
mjx_model = mjx.put_model(mj_model, device=self.device, impl=impl)
mjx_data = mjx.put_data(mj_model, mj_data, device=self.device, impl=impl)
Always pass impl='JAX' to mjx.put_model/put_data — this works
with any JAX backend (cpu, gpu, tpu, mps) and avoids MJX's
platform whitelist check that rejects unknown platforms.

Add mps optional dependency to pyproject.toml.
Document MPS usage in README.
@amacati

amacati commented Jun 17, 2026

Copy link
Copy Markdown
Collaborator

Hey @vlordier , thanks for the PR. As far as I know, there is a jax-metal package for Apple silicon, right? Why not go for that one instead, since it is officially supported by Apple?

@amacati amacati added enhancement New feature or request Apple Silicon labels Jun 17, 2026
@amacati amacati self-assigned this Jun 17, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Apple Silicon enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants