Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4fcdfd0
Merge models back into crazyflow
amacati Jun 15, 2026
a1321af
Introduce consistent naming scheme for dynamics/physics/models
amacati Jun 15, 2026
17a9f2b
Improve naming consistency in docs
amacati Jun 16, 2026
d7e2bee
Finalize consistent naming for drones/models/physics/dynamics
amacati Jun 17, 2026
1a228e9
Rename dynamics file, simplify so_rpy model, simplify models wrapper
amacati Jun 17, 2026
347dded
Remove drone-models submodule
amacati Jun 17, 2026
cb7ec8d
Simplify wrappers, add drone-models test suite to crazyflow
amacati Jun 17, 2026
afa2d3e
Improve tests
amacati Jun 17, 2026
56a9a4f
Fix test imports
amacati Jun 17, 2026
e7e2e06
Rename dynamics docs
amacati Jun 17, 2026
c677f81
Add drone model docs
amacati Jun 17, 2026
026350d
Fix docs
amacati Jun 17, 2026
129fc60
Add controllers back into crazyflow
amacati Jun 17, 2026
ae62616
Improve API and add docs
amacati Jun 17, 2026
1b052ae
Explicitly name controllers
amacati Jun 17, 2026
cc2219f
Remove all submodules. Update deps
amacati Jun 17, 2026
3109cdf
Update doc mistakes and benchmarks
amacati Jun 18, 2026
0715579
Address comments
amacati Jun 18, 2026
8c76a9c
[WIP] Refactor parametrize
amacati Jun 18, 2026
a46fbae
Move fused models into regular xml
amacati Jun 18, 2026
a4bbc8a
Refactor parametrization
amacati Jun 18, 2026
f9aef46
Fix docs and comments
amacati Jun 19, 2026
9bcfac4
Fix model docs. Restructure into core dynamics_euler and wrapper dyna…
amacati Jun 19, 2026
f22b001
Fix doctests
amacati Jun 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions .gitmodules

This file was deleted.

21 changes: 5 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

</div>

Crazyflow is a research simulator for quadrotors. It runs batched, differentiable simulations on CPU and GPU via JAX, with analytical and abstracted models for the Crazyflie 2.x family.
Crazyflow is a research simulator for quadrotors. It runs batched, differentiable simulations on CPU and GPU via JAX, with analytical and abstracted dynamics for the Crazyflie 2.x family.

```python
import numpy as np
Expand All @@ -36,10 +36,10 @@ for _ in range(100):
## Features

- **n\_worlds x n\_drones** — batched over independent environments and multi-drone swarms simultaneously
- **GPU-accelerated** — up to 914 M steps/s on an RTX 4090 (first-principles physics, 262 K worlds)
- **GPU-accelerated** — up to 914 M steps/s on an RTX 4090 (first-principles dynamics, 262 K worlds)
- **Differentiable** — `jax.grad` works through the full dynamics and control pipeline
- **First-principles models** — physics model using first-principles equations and parameters identified from real-world measurements
- **Abstracted models** — three physics models fitted from real Crazyflie flight data
- **First-principles dynamics** — dynamics using first-principles equations and parameters identified from real-world measurements
- **Abstracted dynamics** — simplified dynamics in three flavors fitted from real Crazyflie flight data
- **Modular pipelines** — step and reset are tuples of plain JAX functions; insert anything, anywhere
- **MuJoCo integration** — onscreen and offscreen rendering, raycasting, and contact detection via MJX

Expand All @@ -60,7 +60,7 @@ pixi shell

## Performance

First-principles physics, one drone. CPU: AMD Ryzen 9 7950X. GPU: NVIDIA RTX 4090.
First-principles dynamics, one drone. CPU: AMD Ryzen 9 7950X. GPU: NVIDIA RTX 4090.

| n\_worlds | CPU steps/s | GPU steps/s |
|---|---|---|
Expand All @@ -72,17 +72,6 @@ First-principles physics, one drone. CPU: AMD Ryzen 9 7950X. GPU: NVIDIA RTX 409

Full benchmarks including multi-drone scaling are in the [documentation](https://learnsyslab.github.io/crazyflow).

## Related packages

Crazyflow is built on two companion packages that can also be used independently:

| Package | Description |
|---|---|
| [drone-models](https://github.com/learnsyslab/drone-models) | Drone dynamics models (first-principles and fitted) compatible with NumPy, JAX, and PyTorch. Used by Crazyflow as the physics backend. |
| [drone-controllers](https://github.com/learnsyslab/drone-controllers) | Reference controller implementations including the Mellinger geometric controller. Used by Crazyflow to provide the state and attitude control modes. |

Both are installed automatically as dependencies. For development, they are included as submodules in `submodules/` and installed in editable mode by the pixi environment.

## Citation

```bibtex
Expand Down
6 changes: 4 additions & 2 deletions benchmark/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import jax

from crazyflow import Sim
from crazyflow.control import Control
from crazyflow.dynamics import Dynamics


def main(cache: bool = False):
Expand All @@ -16,7 +18,7 @@ def main(cache: bool = False):

# Time initialization
start = time.perf_counter()
sim = Sim(n_worlds=1, n_drones=1, physics="sys_id", control="attitude")
sim = Sim(n_worlds=1, n_drones=1, dynamics=Dynamics.so_rpy, control=Control.attitude)
init_time = time.perf_counter() - start

# Time reset compilation
Expand All @@ -29,7 +31,7 @@ def main(cache: bool = False):
sim._step.lower(sim.data, 1).compile()
step_time = time.perf_counter() - start

print(f"Simulation startup times | {sim.physics} | {sim.control}")
print(f"Simulation startup times | {sim.dynamics} | {sim.control}")
print(f"Initialization: {init_time:.2f}s")
print(f"Reset: {reset_time:.2f}s")
print(f"Step: {step_time:.2f}s")
Expand Down
4 changes: 2 additions & 2 deletions benchmark/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def profile_gym_env_step(
num_envs=sim_config.n_worlds,
device=sim_config.device,
freq=sim_config.freq,
physics=sim_config.physics,
dynamics=sim_config.dynamics,
)

# Action for going up (in attitude control)
Expand Down Expand Up @@ -151,7 +151,7 @@ def main(device: str = "cpu", n_worlds_exp: int = 6):
sim_config = config_dict.ConfigDict()
sim_config.n_worlds = 1
sim_config.n_drones = 1
sim_config.physics = "first_principles"
sim_config.dynamics = "first_principles"
sim_config.control = "attitude"
sim_config.attitude_freq = 500
sim_config.device = device
Expand Down
7 changes: 3 additions & 4 deletions benchmark/op_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@

def main():
"""Main entry point for profiling."""
sim = Sim(n_worlds=1, n_drones=1, physics="first_principles", control="attitude")
sim = Sim(n_worlds=1, n_drones=1, dynamics="first_principles", control="attitude")

compiled_reset = sim._reset.lower(sim.data, sim.default_data, None).compile()
compiled_step = sim._step.lower(sim.data, 1).compile()
op_count_reset = compiled_reset.cost_analysis()["flops"]
op_count_step = compiled_step.cost_analysis()["flops"]
print(f"Op counts:\n Reset: {op_count_reset}\n Step: {op_count_step}")
print(f"Reset cost analysis: {compiled_reset.cost_analysis()}")
print(f"Step cost analysis: {compiled_step.cost_analysis()}")


if __name__ == "__main__":
Expand Down
11 changes: 7 additions & 4 deletions benchmark/performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):


def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str):
device = jax.devices(device)[0]

envs: ReachPosEnv = gymnasium.make_vec(
"DroneReachPos-v0", time_horizon_in_seconds=2, num_envs=sim_config.n_worlds, **sim_config
"DroneReachPos-v0",
max_episode_time=10.0,
num_envs=sim_config.n_worlds,
dynamics=sim_config.dynamics,
freq=50,
device=device,
)

# Action for going up (in attitude control)
Expand Down Expand Up @@ -77,7 +80,7 @@ def main():
sim_config = config_dict.ConfigDict()
sim_config.n_worlds = 1
sim_config.n_drones = 1
sim_config.physics = "first_principles"
sim_config.dynamics = "first_principles"
sim_config.control = "attitude"
sim_config.device = device

Expand Down
6 changes: 4 additions & 2 deletions crazyflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@

import crazyflow.envs # noqa: F401, ensure gymnasium envs are registered
from crazyflow.control import Control
from crazyflow.sim import Physics, Sim
from crazyflow.drones import available_drones
from crazyflow.dynamics import Dynamics
from crazyflow.sim import Sim

__all__ = ["Sim", "Physics", "Control"]
__all__ = ["Sim", "Dynamics", "Control", "available_drones"]
__version__ = "0.2.0"
8 changes: 8 additions & 0 deletions crazyflow/_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""This file is to be removed as soon as a proper typing is available by the official array-api."""

from typing import Any, TypeAlias

import numpy.typing as npt

Array: TypeAlias = Any # To be changed to array_api_typing later
ArrayLike: TypeAlias = Array | npt.ArrayLike
27 changes: 25 additions & 2 deletions crazyflow/control/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,26 @@
from crazyflow.control.control import Control
"""Implementations of onboard drone controllers in Python.

__all__ = ["Control"]
All controllers are implemented using the array API standard. This means that every controller is
agnostic to the choice of framework and supports e.g. NumPy, JAX, or PyTorch. We also implement all
controllers as pure functions to ensure that users can jit-compile them. All controllers use
broadcasting to support batching of arbitrary leading dimensions.

We reimplement the onboard controller for two reasons:
- We cannot use the C++ bindings of the firmware to differentiate through the onboard controller.
- We need to implement it with JAX to enable efficient, batched computations.
"""

from typing import Callable

__all__ = []

from crazyflow.control.core import Control, parametrize
from crazyflow.control.mellinger import attitude2force_torque as mellinger_attitude2force_torque
from crazyflow.control.mellinger import state2attitude as mellinger_state2attitude

available_controller: dict[str, Callable] = {
"mellinger_state2attitude": mellinger_state2attitude,
"mellinger_attitude2force_torque": mellinger_attitude2force_torque,
}

__all__ = ["Control", "parametrize"]
67 changes: 0 additions & 67 deletions crazyflow/control/control.py

This file was deleted.

134 changes: 134 additions & 0 deletions crazyflow/control/core.py
Comment thread
amacati marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""Core functionalities for controller parametrization."""

from __future__ import annotations

import tomllib
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Callable, ParamSpec, TypeVar

import jax
from jax import Array

from crazyflow.utils import filter_to_signature, to_xp
from crazyflow.utils import parametrize as _parametrize

if TYPE_CHECKING:
from types import ModuleType

from crazyflow._typing import Array # To be changed to array_api_typing later

P = ParamSpec("P")
R = TypeVar("R")


def parametrize(
fn: Callable[P, R], drone: str, xp: ModuleType | None = None, device: str | None = None
) -> Callable[P, R]:
"""Parametrize a controller function with the default controller parameters for a drone.

Args:
fn: The controller function to parametrize.
drone: The drone to use.
xp: The array API module to use. If not provided, numpy is used.
device: The device to use. If None, the device is inferred from the xp module.

Example:
```python
import numpy as np
from crazyflow.control import parametrize
from crazyflow.control.mellinger import state2attitude

ctrl = parametrize(state2attitude, "cf2x_L250")
pos = np.zeros(3)
quat = np.array([0.0, 0.0, 0.0, 1.0])
vel = np.zeros(3)
cmd = np.zeros(13)
rpyt, int_pos_err = ctrl(pos, quat, vel, cmd)
```

Returns:
The parametrized controller function with all keyword argument only parameters filled in.
"""
return _parametrize(fn, drone, load_params, xp=xp, device=device)


def load_params(
fn: Callable, drone: str, xp: ModuleType | None = None, device: str | None = None
) -> dict[str, Array]:
"""Load the parameters a specific controller function accepts.

Merges the ``"core"`` section with the function's ``[drone.<fn_name>]`` section (function values
take precedence), then keeps only the parameters in ``fn``'s signature.

Args:
fn: The controller function for which to load parameters.
drone: Name of the drone configuration, e.g. ``"cf2x_L250"``.
xp: The array API module to use. If not provided, numpy is used.
device: The device to use. If None, the device is inferred from the xp module.

Returns:
A flat dict mapping parameter names to arrays in the requested array namespace.
"""
assert isinstance(fn, Callable), f"Expected a function, got {type(fn)}"
controller = fn.__module__.split(".")[-2]
params_path = Path(__file__).parent / f"{controller}/params.toml"
if not params_path.exists():
raise KeyError(f"`{controller}` not found. Available controllers: {tuple(Control)}")
with open(params_path, "rb") as f:
params = tomllib.load(f)
if drone not in params:
raise KeyError(f"Drone `{drone}` not found in {controller}/params.toml")
merged = params[drone].get("core", {}) | params[drone].get(fn.__name__, {})
return to_xp(filter_to_signature(merged, fn), xp=xp, device=device)


class Control(str, Enum):
"""Control type of the simulated onboard controller."""

state = "state"
"""State control takes [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate].

Note:
Recommended frequency is >=20 Hz.

Warning:
Currently, we only use positions, velocities, and yaw. The rest of the state is ignored.
This is subject to change in the future.
"""
attitude = "attitude"
"""Attitude control takes [roll, pitch, yaw, collective thrust].

Note:
Recommended frequency is >=100 Hz.
"""
force_torque = "force_torque"
"""Force and torque control takes [fc, tx, ty, tz].

Note:
Recommended frequency is >=500 Hz.
"""
rotor_vel = "rotor_vel"
"""Rotor velocity control takes [w1, w2, w3, w4] in RPMs.

Note:
Recommended frequency is >=500 Hz.
"""
default = attitude


@jax.jit
def controllable(step: Array, freq: int, control_steps: Array, control_freq: int) -> Array:
"""Check which worlds can currently update their controllers.

Args:
step: The current step of the simulation.
freq: The frequency of the simulation.
control_steps: The steps at which the controllers were last updated.
control_freq: The frequency of the controllers.

Returns:
A boolean mask of shape (n_worlds,) that is True at the worlds where the controllers can be
updated.
"""
return ((step - control_steps) >= (freq / control_freq)) | (control_steps == -1)
Loading
Loading