diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 0a65728..0000000 --- a/.gitmodules +++ /dev/null @@ -1,8 +0,0 @@ -[submodule "submodules/drone-models"] - path = submodules/drone-models - url = https://github.com/learnsyslab/drone-models.git - branch = main -[submodule "submodules/drone-controllers"] - path = submodules/drone-controllers - url = https://github.com/learnsyslab/drone-controllers.git - branch = main diff --git a/README.md b/README.md index 236a415..badea95 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ -Crazyflow is a research simulator for quadrotors. It runs batched, differentiable simulations on CPU and GPU via JAX, with analytical and abstracted models for the Crazyflie 2.x family. +Crazyflow is a research simulator for quadrotors. It runs batched, differentiable simulations on CPU and GPU via JAX, with analytical and abstracted dynamics for the Crazyflie 2.x family. ```python import numpy as np @@ -36,10 +36,10 @@ for _ in range(100): ## Features - **n\_worlds x n\_drones** — batched over independent environments and multi-drone swarms simultaneously -- **GPU-accelerated** — up to 914 M steps/s on an RTX 4090 (first-principles physics, 262 K worlds) +- **GPU-accelerated** — up to 914 M steps/s on an RTX 4090 (first-principles dynamics, 262 K worlds) - **Differentiable** — `jax.grad` works through the full dynamics and control pipeline -- **First-principles models** — physics model using first-principles equations and parameters identified from real-world measurements -- **Abstracted models** — three physics models fitted from real Crazyflie flight data +- **First-principles dynamics** — dynamics using first-principles equations and parameters identified from real-world measurements +- **Abstracted dynamics** — simplified dynamics in three flavors fitted from real Crazyflie flight data - **Modular pipelines** — step and reset are tuples of plain JAX functions; insert anything, anywhere - **MuJoCo integration** — onscreen and offscreen rendering, raycasting, and contact detection via MJX @@ -60,7 +60,7 @@ pixi shell ## Performance -First-principles physics, one drone. CPU: AMD Ryzen 9 7950X. GPU: NVIDIA RTX 4090. +First-principles dynamics, one drone. CPU: AMD Ryzen 9 7950X. GPU: NVIDIA RTX 4090. | n\_worlds | CPU steps/s | GPU steps/s | |---|---|---| @@ -72,17 +72,6 @@ First-principles physics, one drone. CPU: AMD Ryzen 9 7950X. GPU: NVIDIA RTX 409 Full benchmarks including multi-drone scaling are in the [documentation](https://learnsyslab.github.io/crazyflow). -## Related packages - -Crazyflow is built on two companion packages that can also be used independently: - -| Package | Description | -|---|---| -| [drone-models](https://github.com/learnsyslab/drone-models) | Drone dynamics models (first-principles and fitted) compatible with NumPy, JAX, and PyTorch. Used by Crazyflow as the physics backend. | -| [drone-controllers](https://github.com/learnsyslab/drone-controllers) | Reference controller implementations including the Mellinger geometric controller. Used by Crazyflow to provide the state and attitude control modes. | - -Both are installed automatically as dependencies. For development, they are included as submodules in `submodules/` and installed in editable mode by the pixi environment. - ## Citation ```bibtex diff --git a/benchmark/compile.py b/benchmark/compile.py index 5dacb63..fb4141c 100644 --- a/benchmark/compile.py +++ b/benchmark/compile.py @@ -4,6 +4,8 @@ import jax from crazyflow import Sim +from crazyflow.control import Control +from crazyflow.dynamics import Dynamics def main(cache: bool = False): @@ -16,7 +18,7 @@ def main(cache: bool = False): # Time initialization start = time.perf_counter() - sim = Sim(n_worlds=1, n_drones=1, physics="sys_id", control="attitude") + sim = Sim(n_worlds=1, n_drones=1, dynamics=Dynamics.so_rpy, control=Control.attitude) init_time = time.perf_counter() - start # Time reset compilation @@ -29,7 +31,7 @@ def main(cache: bool = False): sim._step.lower(sim.data, 1).compile() step_time = time.perf_counter() - start - print(f"Simulation startup times | {sim.physics} | {sim.control}") + print(f"Simulation startup times | {sim.dynamics} | {sim.control}") print(f"Initialization: {init_time:.2f}s") print(f"Reset: {reset_time:.2f}s") print(f"Step: {step_time:.2f}s") diff --git a/benchmark/main.py b/benchmark/main.py index e5502a8..5191df1 100644 --- a/benchmark/main.py +++ b/benchmark/main.py @@ -55,7 +55,7 @@ def profile_gym_env_step( num_envs=sim_config.n_worlds, device=sim_config.device, freq=sim_config.freq, - physics=sim_config.physics, + dynamics=sim_config.dynamics, ) # Action for going up (in attitude control) @@ -151,7 +151,7 @@ def main(device: str = "cpu", n_worlds_exp: int = 6): sim_config = config_dict.ConfigDict() sim_config.n_worlds = 1 sim_config.n_drones = 1 - sim_config.physics = "first_principles" + sim_config.dynamics = "first_principles" sim_config.control = "attitude" sim_config.attitude_freq = 500 sim_config.device = device diff --git a/benchmark/op_count.py b/benchmark/op_count.py index 7e81019..b84444f 100644 --- a/benchmark/op_count.py +++ b/benchmark/op_count.py @@ -4,13 +4,12 @@ def main(): """Main entry point for profiling.""" - sim = Sim(n_worlds=1, n_drones=1, physics="first_principles", control="attitude") + sim = Sim(n_worlds=1, n_drones=1, dynamics="first_principles", control="attitude") compiled_reset = sim._reset.lower(sim.data, sim.default_data, None).compile() compiled_step = sim._step.lower(sim.data, 1).compile() - op_count_reset = compiled_reset.cost_analysis()["flops"] - op_count_step = compiled_step.cost_analysis()["flops"] - print(f"Op counts:\n Reset: {op_count_reset}\n Step: {op_count_step}") + print(f"Reset cost analysis: {compiled_reset.cost_analysis()}") + print(f"Step cost analysis: {compiled_step.cost_analysis()}") if __name__ == "__main__": diff --git a/benchmark/performance.py b/benchmark/performance.py index f824f88..1d7c49d 100644 --- a/benchmark/performance.py +++ b/benchmark/performance.py @@ -42,10 +42,13 @@ def profile_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str): def profile_gym_env_step(sim_config: config_dict.ConfigDict, n_steps: int, device: str): - device = jax.devices(device)[0] - envs: ReachPosEnv = gymnasium.make_vec( - "DroneReachPos-v0", time_horizon_in_seconds=2, num_envs=sim_config.n_worlds, **sim_config + "DroneReachPos-v0", + max_episode_time=10.0, + num_envs=sim_config.n_worlds, + dynamics=sim_config.dynamics, + freq=50, + device=device, ) # Action for going up (in attitude control) @@ -77,7 +80,7 @@ def main(): sim_config = config_dict.ConfigDict() sim_config.n_worlds = 1 sim_config.n_drones = 1 - sim_config.physics = "first_principles" + sim_config.dynamics = "first_principles" sim_config.control = "attitude" sim_config.device = device diff --git a/crazyflow/__init__.py b/crazyflow/__init__.py index a550d82..71c3de3 100644 --- a/crazyflow/__init__.py +++ b/crazyflow/__init__.py @@ -18,7 +18,9 @@ import crazyflow.envs # noqa: F401, ensure gymnasium envs are registered from crazyflow.control import Control -from crazyflow.sim import Physics, Sim +from crazyflow.drones import available_drones +from crazyflow.dynamics import Dynamics +from crazyflow.sim import Sim -__all__ = ["Sim", "Physics", "Control"] +__all__ = ["Sim", "Dynamics", "Control", "available_drones"] __version__ = "0.2.0" diff --git a/crazyflow/_typing.py b/crazyflow/_typing.py new file mode 100644 index 0000000..e0a899b --- /dev/null +++ b/crazyflow/_typing.py @@ -0,0 +1,8 @@ +"""This file is to be removed as soon as a proper typing is available by the official array-api.""" + +from typing import Any, TypeAlias + +import numpy.typing as npt + +Array: TypeAlias = Any # To be changed to array_api_typing later +ArrayLike: TypeAlias = Array | npt.ArrayLike diff --git a/crazyflow/control/__init__.py b/crazyflow/control/__init__.py index 51c8045..90ebc9e 100644 --- a/crazyflow/control/__init__.py +++ b/crazyflow/control/__init__.py @@ -1,3 +1,26 @@ -from crazyflow.control.control import Control +"""Implementations of onboard drone controllers in Python. -__all__ = ["Control"] +All controllers are implemented using the array API standard. This means that every controller is +agnostic to the choice of framework and supports e.g. NumPy, JAX, or PyTorch. We also implement all +controllers as pure functions to ensure that users can jit-compile them. All controllers use +broadcasting to support batching of arbitrary leading dimensions. + +We reimplement the onboard controller for two reasons: +- We cannot use the C++ bindings of the firmware to differentiate through the onboard controller. +- We need to implement it with JAX to enable efficient, batched computations. +""" + +from typing import Callable + +__all__ = [] + +from crazyflow.control.core import Control, parametrize +from crazyflow.control.mellinger import attitude2force_torque as mellinger_attitude2force_torque +from crazyflow.control.mellinger import state2attitude as mellinger_state2attitude + +available_controller: dict[str, Callable] = { + "mellinger_state2attitude": mellinger_state2attitude, + "mellinger_attitude2force_torque": mellinger_attitude2force_torque, +} + +__all__ = ["Control", "parametrize"] diff --git a/crazyflow/control/control.py b/crazyflow/control/control.py deleted file mode 100644 index 767dc6c..0000000 --- a/crazyflow/control/control.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Functional programming implementation of the onboard controller. - -We reimplement the onboard controller for two reasons: -- We cannot use the C++ bindings of the firmware to differentiate through the onboard controller. -- We need to implement it with JAX to enable efficient, batched computations. - -Since our controller is a PID controller, it requires integration of the error over time. We opt for -a functional implementation to avoid storing any state in the class. Doing so would either prevent -us from easily scaling across batches and drones with JAX's `vmap`, or require us to support batches -and multiple drones explicitly in the controller. -""" - -from enum import Enum - -import jax -from jax import Array - - -class Control(str, Enum): - """Control type of the simulated onboard controller.""" - - state = "state" - """State control takes [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate]. - - Note: - Recommended frequency is >=20 Hz. - - Warning: - Currently, we only use positions, velocities, and yaw. The rest of the state is ignored. - This is subject to change in the future. - """ - attitude = "attitude" - """Attitude control takes [roll, pitch, yaw, collective thrust]. - - Note: - Recommended frequency is >=100 Hz. - """ - force_torque = "force_torque" - """Force and torque control takes [fc, tx, ty, tz]. - - Note: - Recommended frequency is >=500 Hz. - """ - rotor_vel = "rotor_vel" - """Rotor velocity control takes [w1, w2, w3, w4] in RPMs. - - Note: - Recommended frequency is >=500 Hz. - """ - default = attitude - - -@jax.jit -def controllable(step: Array, freq: int, control_steps: Array, control_freq: int) -> Array: - """Check which worlds can currently update their controllers. - - Args: - step: The current step of the simulation. - freq: The frequency of the simulation. - control_steps: The steps at which the controllers were last updated. - control_freq: The frequency of the controllers. - - Returns: - A boolean mask of shape (n_worlds,) that is True at the worlds where the controllers can be - updated. - """ - return ((step - control_steps) >= (freq / control_freq)) | (control_steps == -1) diff --git a/crazyflow/control/core.py b/crazyflow/control/core.py new file mode 100644 index 0000000..fde4393 --- /dev/null +++ b/crazyflow/control/core.py @@ -0,0 +1,134 @@ +"""Core functionalities for controller parametrization.""" + +from __future__ import annotations + +import tomllib +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Callable, ParamSpec, TypeVar + +import jax +from jax import Array + +from crazyflow.utils import filter_to_signature, to_xp +from crazyflow.utils import parametrize as _parametrize + +if TYPE_CHECKING: + from types import ModuleType + + from crazyflow._typing import Array # To be changed to array_api_typing later + +P = ParamSpec("P") +R = TypeVar("R") + + +def parametrize( + fn: Callable[P, R], drone: str, xp: ModuleType | None = None, device: str | None = None +) -> Callable[P, R]: + """Parametrize a controller function with the default controller parameters for a drone. + + Args: + fn: The controller function to parametrize. + drone: The drone to use. + xp: The array API module to use. If not provided, numpy is used. + device: The device to use. If None, the device is inferred from the xp module. + + Example: + ```python + import numpy as np + from crazyflow.control import parametrize + from crazyflow.control.mellinger import state2attitude + + ctrl = parametrize(state2attitude, "cf2x_L250") + pos = np.zeros(3) + quat = np.array([0.0, 0.0, 0.0, 1.0]) + vel = np.zeros(3) + cmd = np.zeros(13) + rpyt, int_pos_err = ctrl(pos, quat, vel, cmd) + ``` + + Returns: + The parametrized controller function with all keyword argument only parameters filled in. + """ + return _parametrize(fn, drone, load_params, xp=xp, device=device) + + +def load_params( + fn: Callable, drone: str, xp: ModuleType | None = None, device: str | None = None +) -> dict[str, Array]: + """Load the parameters a specific controller function accepts. + + Merges the ``"core"`` section with the function's ``[drone.]`` section (function values + take precedence), then keeps only the parameters in ``fn``'s signature. + + Args: + fn: The controller function for which to load parameters. + drone: Name of the drone configuration, e.g. ``"cf2x_L250"``. + xp: The array API module to use. If not provided, numpy is used. + device: The device to use. If None, the device is inferred from the xp module. + + Returns: + A flat dict mapping parameter names to arrays in the requested array namespace. + """ + assert isinstance(fn, Callable), f"Expected a function, got {type(fn)}" + controller = fn.__module__.split(".")[-2] + params_path = Path(__file__).parent / f"{controller}/params.toml" + if not params_path.exists(): + raise KeyError(f"`{controller}` not found. Available controllers: {tuple(Control)}") + with open(params_path, "rb") as f: + params = tomllib.load(f) + if drone not in params: + raise KeyError(f"Drone `{drone}` not found in {controller}/params.toml") + merged = params[drone].get("core", {}) | params[drone].get(fn.__name__, {}) + return to_xp(filter_to_signature(merged, fn), xp=xp, device=device) + + +class Control(str, Enum): + """Control type of the simulated onboard controller.""" + + state = "state" + """State control takes [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate]. + + Note: + Recommended frequency is >=20 Hz. + + Warning: + Currently, we only use positions, velocities, and yaw. The rest of the state is ignored. + This is subject to change in the future. + """ + attitude = "attitude" + """Attitude control takes [roll, pitch, yaw, collective thrust]. + + Note: + Recommended frequency is >=100 Hz. + """ + force_torque = "force_torque" + """Force and torque control takes [fc, tx, ty, tz]. + + Note: + Recommended frequency is >=500 Hz. + """ + rotor_vel = "rotor_vel" + """Rotor velocity control takes [w1, w2, w3, w4] in RPMs. + + Note: + Recommended frequency is >=500 Hz. + """ + default = attitude + + +@jax.jit +def controllable(step: Array, freq: int, control_steps: Array, control_freq: int) -> Array: + """Check which worlds can currently update their controllers. + + Args: + step: The current step of the simulation. + freq: The frequency of the simulation. + control_steps: The steps at which the controllers were last updated. + control_freq: The frequency of the controllers. + + Returns: + A boolean mask of shape (n_worlds,) that is True at the worlds where the controllers can be + updated. + """ + return ((step - control_steps) >= (freq / control_freq)) | (control_steps == -1) diff --git a/crazyflow/control/mellinger.py b/crazyflow/control/mellinger.py deleted file mode 100644 index 262a3f6..0000000 --- a/crazyflow/control/mellinger.py +++ /dev/null @@ -1,112 +0,0 @@ -from __future__ import annotations - -import jax.numpy as jnp -from drone_controllers.core import load_params -from drone_controllers.mellinger import ( - attitude2force_torque, - force_torque2rotor_vel, - state2attitude, -) -from flax.struct import dataclass, field -from jax import Array, Device - - -@dataclass -class MellingerStateData: - cmd: Array # (N, M, 13) - """Full state control command for the drone. - - A command consists of [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate]. - We currently do not use the acceleration and angle rate components. This is subject to change. - """ - staged_cmd: Array # (N, M, 13) - """Staging buffer to store the most recent command until the next controller tick.""" - steps: Array # (N, 1) - """Last simulation steps that the state control command was applied.""" - freq: int = field(pytree_node=False) - """Frequency of the state control command.""" - pos_err_i: Array # (N, M, 3) - """Integral errors of the state control command.""" - # Parameters for the state controller - params: dict[str, Array] - - @staticmethod - def create( - n_worlds: int, n_drones: int, freq: int, drone_model: str, device: Device - ) -> MellingerStateData: - """Create a default set of state data for the simulation.""" - cmd = jnp.zeros((n_worlds, n_drones, 13), device=device) - steps = -jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device) - pos_err_i = jnp.zeros((n_worlds, n_drones, 3), device=device) - params = load_params(state2attitude, drone_model, xp=jnp, device=device) - return MellingerStateData( - cmd=cmd, staged_cmd=cmd, steps=steps, freq=freq, pos_err_i=pos_err_i, params=params - ) - - -@dataclass -class MellingerAttitudeData: - cmd: Array # (N, M, 4) - """Full attitude control command for the drone. - - A command consists of [roll, pitch, yaw, collective thrust]. - """ - staged_cmd: Array # (N, M, 4) - """Staging buffer to store the most recent command until the next controller tick.""" - steps: Array # (N, 1) - """Last simulation steps that the attitude control command was applied.""" - freq: int = field(pytree_node=False) - """Frequency of the attitude control command.""" - r_int_error: Array # (N, M, 3) - """Integral errors of the attitude control command.""" - last_ang_vel: Array # (N, M, 3) - """Last angular velocity of the drone.""" - # Parameters for the attitude controller - params: dict[str, Array] - - @staticmethod - def create( - n_worlds: int, n_drones: int, freq: int, drone_model: str, device: Device - ) -> MellingerAttitudeData: - """Create a default set of attitude data for the simulation.""" - cmd = jnp.zeros((n_worlds, n_drones, 4), device=device) - steps = -jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device) - zeros_3d = jnp.zeros((n_worlds, n_drones, 3), device=device) - params = load_params(attitude2force_torque, drone_model, xp=jnp, device=device) - return MellingerAttitudeData( - cmd=cmd, - staged_cmd=cmd, - steps=steps, - freq=freq, - r_int_error=zeros_3d, - last_ang_vel=zeros_3d, - params=params, - ) - - -@dataclass -class MellingerForceTorqueData: - cmd: Array # (N, M, 4) - """Force-torque command for the drone. - - A command consists of [fz, tx, ty, tz]. - """ - staged_cmd: Array # (N, M, 4) - """Staging buffer to store the most recent command until the next controller tick.""" - steps: Array # (N, 1) - """Last simulation steps that the force and torque control command was applied.""" - freq: int = field(pytree_node=False) - """Frequency of the force and torque control command.""" - # Parameters for the force and torque controller - params: dict[str, Array] - - @staticmethod - def create( - n_worlds: int, n_drones: int, freq: int, drone_model: str, device: Device - ) -> MellingerForceTorqueData: - zero_4d = jnp.zeros((n_worlds, n_drones, 4), device=device) - steps = -jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device) - params = load_params(force_torque2rotor_vel, drone_model, xp=jnp, device=device) - return MellingerForceTorqueData( - cmd=zero_4d, staged_cmd=zero_4d, steps=steps, freq=freq, params=params - ) diff --git a/crazyflow/control/mellinger/__init__.py b/crazyflow/control/mellinger/__init__.py new file mode 100644 index 0000000..570613b --- /dev/null +++ b/crazyflow/control/mellinger/__init__.py @@ -0,0 +1,30 @@ +"""Mellinger controller reimplementation based on the Crazyflie firmware. + +See https://ieeexplore.ieee.org/document/5980409 for details. +""" + +from crazyflow.control.mellinger.control import ( + MellingerAttitudeData, + MellingerForceTorqueData, + MellingerStateData, + attitude2force_torque, + control_attitude2force_torque, + control_commit_attitude, + control_force_torque2rotor_vel, + control_state2attitude, + force_torque2rotor_vel, + state2attitude, +) + +__all__ = [ + "state2attitude", + "attitude2force_torque", + "force_torque2rotor_vel", + "MellingerStateData", + "MellingerAttitudeData", + "MellingerForceTorqueData", + "control_state2attitude", + "control_attitude2force_torque", + "control_commit_attitude", + "control_force_torque2rotor_vel", +] diff --git a/crazyflow/control/mellinger/control.py b/crazyflow/control/mellinger/control.py new file mode 100644 index 0000000..d0a1d56 --- /dev/null +++ b/crazyflow/control/mellinger/control.py @@ -0,0 +1,480 @@ +"""Mellinger controller reimplementation based on the Crazyflie firmware. + +The controller is split into three pure functions that form a pipeline: +``state2attitude`` → ``attitude2force_torque`` → ``force_torque2rotor_vel``. +Each stage can be used independently or chained together to produce per-motor +RPM commands from a full-state setpoint. + +Reference: D. Mellinger and V. Kumar, "Minimum snap trajectory generation and +control for quadrotors", ICRA 2011. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import array_api_extra as xpx +import jax.numpy as jnp +from array_api_compat import array_namespace +from flax.struct import dataclass, field +from scipy.spatial.transform import Rotation as R + +from crazyflow.control.core import controllable, load_params +from crazyflow.control.transform import force2pwm, motor_force2rotor_vel, pwm2force +from crazyflow.utils import leaf_replace + +if TYPE_CHECKING: + from jax import Device + + from crazyflow._typing import Array # To be changed to array_api_typing later + from crazyflow.sim.data import SimData + + +def state2attitude( + pos: Array, + quat: Array, + vel: Array, + cmd: Array, + pos_err_i: Array | None = None, + ctrl_freq: float = 100, + *, + mass: float, + kp: Array, + kd: Array, + ki: Array, + gravity_vec: Array, + mass_thrust: float, + int_err_max: Array, + thrust_max: float, + pwm_max: float, +) -> tuple[Array, Array]: + """Compute the positional part of the mellinger controller. + + All controllers are implemented as pure functions. Therefore, integral errors have to be passed + as an argument and returned as well. + + Args: + pos: Drone position with shape (..., 3). + quat: Drone orientation as xyzw quaternion with shape (..., 4). + vel: Drone velocity with shape (..., 3). + cmd: Full state command in SI units and rad with shape (..., 13). The entries are + [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate]. + pos_err_i: Position integral error (..., 3) from the previous call. If None, it is + initialised to zero. + ctrl_freq: Control frequency in Hz + mass: Drone mass used for calculations in the controller in kg. + kp: Proportional gain for the position controller with shape (3,). + kd: Derivative gain for the position controller with shape (3,). + ki: Integral gain for the position controller with shape (3,). + gravity_vec: Gravity vector with shape (3,). We assume gravity to be in the negative z + direction. E.g., [0, 0, -9.81]. + mass_thrust: Conversion factor from thrust to PWM. + int_err_max: Range of the integral error with shape (3,). i_range in the firmware. + thrust_max: Maximum thrust in N. + pwm_max: Maximum PWM value. + + Returns: + The RPY collective thrust command [rad, rad, rad, N], and the integral error of the position + controller. + """ + xp = array_namespace(pos) + + setpoint_pos = cmd[..., 0:3] + setpoint_vel = cmd[..., 3:6] + setpoint_acc = cmd[..., 6:9] + setpoint_yaw = cmd[..., 9] + dt = 1 / ctrl_freq + # setpointRPY_rates = cmd[..., 10:13] + # From firmware controller_mellinger + pos_err = setpoint_pos - pos # l. 145 Position Error (ep) + vel_err = setpoint_vel - vel # l. 148 Velocity Error (ev) + # l.151 ff Integral Error + int_pos_err = xp.zeros_like(pos) if pos_err_i is None else pos_err_i + int_pos_err = xp.clip(int_pos_err + pos_err * dt, -int_err_max, int_err_max) + # l. 161 Desired thrust [F_des] + # => only one case here, since setpoint is always in absolute mode + # Note: since we've defined the gravity in z direction, a "-" needs to be added + target_thrust = ( + mass * (setpoint_acc - gravity_vec) + kp * pos_err + kd * vel_err + ki * int_pos_err + ) + # l. 178 Rate-controlled YAW is moving YAW angle setpoint + # => only one case here, since the setpoint is always in absolute mode + desired_yaw = setpoint_yaw + # l. 189 Z-Axis [zB] + rot = R.from_quat(quat).as_matrix() + z_axis = rot[..., -1] # 3rd column or roation matrix is z axis + # l. 194 yaw correction (only if position control is not used) + # => skipped since we always use position control here + + # l. 204 Current thrust [F] + # Taking the dot product of the last axis: + current_thrust = xp.vecdot(target_thrust, z_axis, axis=-1) + # l. 207 Calculate axis [zB_des] + z_axis_desired = target_thrust / xp.linalg.vector_norm(target_thrust) + # l. 210 [xC_des] + # x_axis_desired = z_axis_desired x [sin(yaw), cos(yaw), 0]^T + x_c_des_x = xp.cos(desired_yaw) + x_c_des_y = xp.sin(desired_yaw) + x_c_des_z = xp.zeros_like(x_c_des_x) + x_c_des = xp.stack((x_c_des_x, x_c_des_y, x_c_des_z), axis=-1) + # [yB_des] + y_axis_desired = xp.linalg.cross(z_axis_desired, x_c_des) + y_axis_desired = y_axis_desired / xp.linalg.vector_norm(y_axis_desired) + # [xB_des] + x_axis_desired = xp.linalg.cross(y_axis_desired, z_axis_desired) + # converting desired axis to rotation matrix and then to RPY. + matrix = xp.stack((x_axis_desired, y_axis_desired, z_axis_desired), axis=-1) + # l. 220 [eR] The mellinger controller now continues with the attitude controller. However, we + # decouple the attitude controller from the state controller. We therefore stop here and + # continue the computation in the attitude2force_torque controller. The conversion to RPY is + # necessary to pass the command to the attitude2force_torque controller in the correct format. + command_RPY = R.from_matrix(matrix).as_euler("xyz", degrees=False) + # l. 283 [control_thrust] + # The firmware returns thrust in PWM, but we want to stay in SI units. The conversion from + # thrust to PWM uses a mass_thrust parameter, which is a constant converting thrust values to + # PWMs. This transformation changes the thrust value, because it is fixed to a specific value + # instead of dynamically scaling with the mass parameter of the controller! Hence, we include + # this conversion here and thus effectively rescale the thrust slightly. The conversion below + # maps thrust -> PWM -> rescaled thrust. + thrust = pwm2force(mass_thrust * current_thrust, thrust_max * 4, pwm_max) + command_rpyt = xp.concat((command_RPY, thrust[..., None]), axis=-1) + return command_rpyt, int_pos_err + + +def attitude2force_torque( + quat: Array, + ang_vel: Array, + cmd: Array, + prev_ang_vel: Array | None = None, + r_int_error: Array | None = None, + ctrl_freq: int = 500, + *, + kR: Array, + kw: Array, + ki_m: Array, + kd_omega: Array, + int_err_max: Array, + torque_pwm_max: Array, + thrust_max: float, + pwm_min: float, + pwm_max: float, + L: float, + thrust2torque: float, + mixing_matrix: Array, +) -> tuple[Array, Array, Array]: + """Compute the attitude to desired force-torque part of the Mellinger controller. + + Note: + We omit the axis flip in the firmware as it has only been introduced to make the controller + compatible with the new frame of the Crazyflie 2.1. + + Args: + quat: Drone orientation as xyzw quaternion with shape (..., 4). + ang_vel: Drone angular drone velocity in rad/s with shape (..., 3). + cmd: Commanded attitude (roll, pitch, yaw) and total thrust [rad, rad, rad, N]. + r_int_error: Angular velocity integral error (..., 3) from the previous call. If None, it + is initialised to zero. + ctrl_freq: Control frequency in Hz + kR: Proportional gain for the rotation error with shape (3,). + kw: Proportional gain for the angular velocity error with shape (3,). + ki_m: Integral gain for the rotation error with shape (3,). + kd_omega: Derivative gain for the angular velocity error with shape (3,). + int_err_max: Range of the integral error with shape (3,). i_range in the firmware. + torque_pwm_max: Maximum torque in PWM. + thrust_max: Maximum thrust in N. + pwm_min: Minimum PWM value. + pwm_max: Maximum PWM value. + prev_ang_vel: Previous angular velocity in rad/s. + L: Distance from the center of the quadrotor to the center of the rotor in m. + thrust2torque: Conversion factor (m). + mixing_matrix: Mixing matrix for the motor forces with shape (4, 3). + + Returns: + 4 Motor forces [N], i_error_m + """ + xp = array_namespace(quat) + force_des = cmd[..., 3] # Total thrust in N + rpy_des = cmd[..., :3] + dt = 1 / ctrl_freq + # l. 220 ff [eR]. We're using the "inefficient" code path from the firmware + rot = R.from_quat(quat) + rot_des = R.from_euler("xyz", rpy_des, degrees=False) + # Equivalent to eRM = R_des.T @ R_act - R_act.T @ R_des + # Firmware does not multiply by 0.5 here, but the original paper does. We replicate the firmware + # exactly to avoid sim2real issues with the original controller parameters. + R_delta = (rot_des.inv() * rot).as_matrix() + eRM = R_delta - R_delta.mT + # Vee operator (SO3 to R3) + eR = xp.stack((eRM[..., 2, 1], eRM[..., 0, 2], eRM[..., 1, 0]), axis=-1) + # l.248 ff [ew] + # Warning: We assume zero desired angular velocity + ang_vel_des = xp.zeros_like(ang_vel) + prev_ang_vel_des = xp.zeros_like(ang_vel) + ew = ang_vel_des - ang_vel + # WARNING: if the setpoint is ever != 0 => change sign of ew.y! + + # l.259 ff [err_d_rpy] + prev_ang_vel = xp.zeros_like(ang_vel) if prev_ang_vel is None else prev_ang_vel + ang_vel_d_err = ((ang_vel_des - prev_ang_vel_des) - (ang_vel - prev_ang_vel)) / dt + # l.281: No err_d_yaw + ang_vel_d_err = xpx.at(ang_vel_d_err)[..., 2].set(0) + + # l. 268 ff Integral Error + r_int_error = xp.zeros_like(ang_vel) if r_int_error is None else r_int_error + r_int_error = r_int_error - eR * dt + r_int_error = xp.clip(r_int_error, -int_err_max, int_err_max) + # l. 278 ff Moment: + torque_pwm = -kR * eR + kw * ew + ki_m * r_int_error + kd_omega * ang_vel_d_err + # l. 297 ff + torque_pwm = xp.clip(torque_pwm, -torque_pwm_max, torque_pwm_max) + torque_pwm = xp.where((force_des > 0)[..., None], torque_pwm, 0.0) + force_des_pwm = force2pwm(force_des / 4, thrust_max, pwm_max) + pwms = force_torque_pwms2pwms(force_des_pwm, torque_pwm, mixing_matrix) + pwms = xp.where(xp.all(pwms == 0), 0.0, xp.clip(pwms, pwm_min, pwm_max)) + + # Info: The Mellinger controller in the firmware ends here. However, we enforce a standardized + # interface in the simulation from states -> attitude -> force_torque. We therefore need this + # function to convert from PWMs to forces and torques. + # In the firmware, this is done implicitly with the motor mixing. We therefore do the motor + # mixing here, calculate the resulting force and torque, and return them. + # This process is then reversed in the next step, where we recover the desired motor forces from + # the force and torque. + motor_forces = pwm2force(pwms, thrust_max, pwm_max) + # TODO: Long-term, the Mellinger controller should use the new power distribution which + # calculates motor forces in Newtons. However, for now the firmware uses the legacy power + # distribution, so we keep it here for compatibility. To have a single consistent interface for + # controllers, we still want to return SI forces and torques. We thus need to convert the legacy + # output to SI units. + # l. 310 ff + torque_des = (mixing_matrix @ motor_forces[..., None])[..., 0] * xp.stack([L, L, thrust2torque]) + force_des = xp.sum(motor_forces, axis=-1)[..., None] + return force_des, torque_des, r_int_error + + +def force_torque_pwms2pwms(force_pwm: Array, torque_pwm: Array, mixing_matrix: Array) -> Array: + """Convert desired collective thrust and torques to rotor speeds using legacy behavior.""" + xp = array_namespace(force_pwm) + torque_pwm = xp.concatenate((torque_pwm[..., :2] / 2, torque_pwm[..., 2:]), axis=-1) + return force_pwm[..., None] + (torque_pwm @ mixing_matrix) + + +def force_torque2rotor_vel( + force: Array, + torque: Array, + *, + thrust_min: float, + thrust_max: float, + L: float, + rpm2thrust: Array, + thrust2torque: float, + mixing_matrix: Array, +) -> Array: + """Convert desired collective thrust and torques to rotor speeds. + + The firmware calculates PWMs for each motor, compensates for the battery voltage, and then + applies the modified PWMs to the motors. We assume perfect battery compensation here, skip the + PWM interface except for clipping, and instead return desired motor forces. + + Note: + The equivalent function in the crazyflie firmware is power_distribution from + power_distribution_quadrotor.c. + + Warning: + This function assumes an X rotor configuration. + + Args: + force: Desired thrust in SI units with shape (...,). + torque: Desired torque in SI units with shape (..., 3). + thrust_min: Minimum thrust in N. + thrust_max: Maximum thrust in N. + L: Distance from the center of the quadrotor to the center of the rotor in m. + rpm2thrust: Force constants (N/RPM, N/RPM**2). + thrust2torque: Conversion factor (m). + mixing_matrix: Mixing matrix for the motor forces with shape (4, 3). + + Returns: + The desired motor forces in SI units with shape (..., 4). + """ + xp = array_namespace(torque) + assert torque.shape[-1] == 3, f"Torque must have shape (..., 3), but has {torque.shape}" + assert force.shape[-1] == 1, f"Force must have shape (..., 1), but has {force.shape}" + torque_forces = (torque * xp.asarray([1 / L, 1 / L, 1 / thrust2torque])) @ mixing_matrix + motor_forces = (torque_forces + force) / 4 + # Clip motor forces on the thrust instead of PWM level. + motor_forces = xp.where(xp.all(force == 0), 0.0, xp.clip(motor_forces, thrust_min, thrust_max)) + # Assume perfect battery compensation and calculate the desired motor speeds directly + return motor_force2rotor_vel(motor_forces, rpm2thrust) + + +@dataclass +class MellingerStateData: + cmd: Array # (N, M, 13) + """Full state control command for the drone. + + A command consists of [x, y, z, vx, vy, vz, ax, ay, az, yaw, roll_rate, pitch_rate, yaw_rate]. + We currently do not use the acceleration and angle rate components. This is subject to change. + """ + staged_cmd: Array # (N, M, 13) + """Staging buffer to store the most recent command until the next controller tick.""" + steps: Array # (N, 1) + """Last simulation steps that the state control command was applied.""" + freq: int = field(pytree_node=False) + """Frequency of the state control command.""" + pos_err_i: Array # (N, M, 3) + """Integral errors of the state control command.""" + # Parameters for the state controller + params: dict[str, Array] + + @staticmethod + def create( + n_worlds: int, n_drones: int, freq: int, drone: str, device: Device + ) -> MellingerStateData: + """Create a default set of state data for the simulation.""" + cmd = jnp.zeros((n_worlds, n_drones, 13), device=device) + steps = -jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device) + pos_err_i = jnp.zeros((n_worlds, n_drones, 3), device=device) + params = load_params(state2attitude, drone, xp=jnp, device=device) + return MellingerStateData( + cmd=cmd, staged_cmd=cmd, steps=steps, freq=freq, pos_err_i=pos_err_i, params=params + ) + + +@dataclass +class MellingerAttitudeData: + cmd: Array # (N, M, 4) + """Full attitude control command for the drone. + + A command consists of [roll, pitch, yaw, collective thrust]. + """ + staged_cmd: Array # (N, M, 4) + """Staging buffer to store the most recent command until the next controller tick.""" + steps: Array # (N, 1) + """Last simulation steps that the attitude control command was applied.""" + freq: int = field(pytree_node=False) + """Frequency of the attitude control command.""" + r_int_error: Array # (N, M, 3) + """Integral errors of the attitude control command.""" + last_ang_vel: Array # (N, M, 3) + """Last angular velocity of the drone.""" + # Parameters for the attitude controller + params: dict[str, Array] + + @staticmethod + def create( + n_worlds: int, n_drones: int, freq: int, drone: str, device: Device + ) -> MellingerAttitudeData: + """Create a default set of attitude data for the simulation.""" + cmd = jnp.zeros((n_worlds, n_drones, 4), device=device) + steps = -jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device) + zeros_3d = jnp.zeros((n_worlds, n_drones, 3), device=device) + params = load_params(attitude2force_torque, drone, xp=jnp, device=device) + return MellingerAttitudeData( + cmd=cmd, + staged_cmd=cmd, + steps=steps, + freq=freq, + r_int_error=zeros_3d, + last_ang_vel=zeros_3d, + params=params, + ) + + +@dataclass +class MellingerForceTorqueData: + cmd: Array # (N, M, 4) + """Force-torque command for the drone. + + A command consists of [fz, tx, ty, tz]. + """ + staged_cmd: Array # (N, M, 4) + """Staging buffer to store the most recent command until the next controller tick.""" + steps: Array # (N, 1) + """Last simulation steps that the force and torque control command was applied.""" + freq: int = field(pytree_node=False) + """Frequency of the force and torque control command.""" + # Parameters for the force and torque controller + params: dict[str, Array] + + @staticmethod + def create( + n_worlds: int, n_drones: int, freq: int, drone: str, device: Device + ) -> MellingerForceTorqueData: + zero_4d = jnp.zeros((n_worlds, n_drones, 4), device=device) + steps = -jnp.ones((n_worlds, 1), dtype=jnp.int32, device=device) + params = load_params(force_torque2rotor_vel, drone, xp=jnp, device=device) + return MellingerForceTorqueData( + cmd=zero_4d, staged_cmd=zero_4d, steps=steps, freq=freq, params=params + ) + + +def control_state2attitude(data: SimData) -> SimData: + """Compute the updated controls for the state controller.""" + states = data.states + state_ctrl: MellingerStateData = data.controls.state + assert state_ctrl is not None, "Using state controller without initialized data" + mask = controllable(data.core.steps, data.core.freq, state_ctrl.steps, state_ctrl.freq) + state_ctrl = leaf_replace(state_ctrl, mask, cmd=state_ctrl.staged_cmd) + rpyt, pos_err_i = state2attitude( + states.pos, + states.quat, + states.vel, + state_ctrl.cmd, + pos_err_i=state_ctrl.pos_err_i, + ctrl_freq=state_ctrl.freq, + **state_ctrl.params, + ) + state_ctrl = leaf_replace(state_ctrl, mask, steps=data.core.steps, pos_err_i=pos_err_i) + attitude_ctrl = leaf_replace(data.controls.attitude, mask, staged_cmd=rpyt) + return data.replace(controls=data.controls.replace(state=state_ctrl, attitude=attitude_ctrl)) + + +def control_attitude2force_torque(data: SimData) -> SimData: + """Compute the updated controls for the attitude controller.""" + states = data.states + attitude_ctrl: MellingerAttitudeData = data.controls.attitude + assert attitude_ctrl is not None, "Using attitude controller without initialized data" + mask = controllable(data.core.steps, data.core.freq, attitude_ctrl.steps, attitude_ctrl.freq) + attitude_ctrl = leaf_replace(attitude_ctrl, mask, cmd=attitude_ctrl.staged_cmd) + force, torque, r_int_error = attitude2force_torque( + states.quat, + states.ang_vel, + attitude_ctrl.cmd, + r_int_error=attitude_ctrl.r_int_error, + ctrl_freq=attitude_ctrl.freq, + prev_ang_vel=attitude_ctrl.last_ang_vel, + **attitude_ctrl.params, + ) + attitude_ctrl = leaf_replace( + attitude_ctrl, + mask, + r_int_error=r_int_error, + last_ang_vel=states.ang_vel, + steps=data.core.steps, + ) + ft_ctrl = leaf_replace( + data.controls.force_torque, mask, staged_cmd=jnp.concat([force, torque], axis=-1) + ) + return data.replace( + states=states, controls=data.controls.replace(attitude=attitude_ctrl, force_torque=ft_ctrl) + ) + + +def control_commit_attitude(data: SimData) -> SimData: + """Commit the staged attitude command to the controller setpoint.""" + attitude_ctrl: MellingerAttitudeData = data.controls.attitude + mask = controllable(data.core.steps, data.core.freq, attitude_ctrl.steps, attitude_ctrl.freq) + attitude_ctrl = leaf_replace(attitude_ctrl, mask, cmd=attitude_ctrl.staged_cmd) + return data.replace(controls=data.controls.replace(attitude=attitude_ctrl)) + + +def control_force_torque2rotor_vel(data: SimData) -> SimData: + """Compute the updated controls for the thrust controller.""" + ft_ctrl: MellingerForceTorqueData = data.controls.force_torque + assert ft_ctrl is not None, "Using force torque controller without initialized data" + mask = controllable(data.core.steps, data.core.freq, ft_ctrl.steps, ft_ctrl.freq) + ft_ctrl = leaf_replace(ft_ctrl, mask, cmd=ft_ctrl.staged_cmd) + rotor_vel = force_torque2rotor_vel( + ft_ctrl.cmd[..., [0]], ft_ctrl.cmd[..., 1:], **ft_ctrl.params + ) + ft_ctrl = leaf_replace(ft_ctrl, mask, steps=data.core.steps) + return data.replace(controls=data.controls.replace(rotor_vel=rotor_vel, force_torque=ft_ctrl)) diff --git a/crazyflow/control/mellinger/params.toml b/crazyflow/control/mellinger/params.toml new file mode 100644 index 0000000..9589be8 --- /dev/null +++ b/crazyflow/control/mellinger/params.toml @@ -0,0 +1,131 @@ +[cf2x_L250] +[cf2x_L250.core] +mass = 0.029 # The controller is using the wrong mass by default +mass_thrust = 132000 +pwm_min = 7000 +pwm_max = 65535 +thrust_min = 0.012817578393224994 # in N per motor +thrust_max = 0.12 # in N per motor +torque_pwm_max = [32000.0, 32000.0, 32000.0] +L = 0.03253 +rpm2thrust = [0.0, -5.382196214637237e-7, 2.4582929831265485e-10] +rpm2torque = [0.0, 1.410454111996297e-9, 1.4592584373980652e-12] +thrust2torque = 0.007350862856566459 +mixing_matrix = [ + [-1.0, -1.0, 1.0, 1.0], + [-1.0, 1.0, 1.0, -1.0], + [-1.0, 1.0, -1.0, 1.0] +] +gravity_vec = [0.0, 0.0, -9.81] + +[cf2x_L250.state2attitude] +kp = [0.4, 0.4, 1.25] +kd = [0.2, 0.2, 0.5] +ki = [0.05, 0.05, 0.05] +int_err_max = [2.0, 2.0, 0.4] + +[cf2x_L250.attitude2force_torque] +kR = [70000.0, 70000.0, 60000.0] +kw = [20000.0, 20000.0, 12000.0] +ki_m = [0.0, 0.0, 500.0] +kd_omega = [200.0, 200.0, 0.0] +int_err_max = [1.0, 1.0, 1500.0] + +[cf2x_P250] +[cf2x_P250.core] +mass = 0.029 # The controller is using the wrong mass by default +mass_thrust = 132000 +pwm_min = 7000 +pwm_max = 65535 +thrust_min = 0.012817578393224994 # in N per motor +thrust_max = 0.12 # in N per motor +torque_pwm_max = [32000.0, 32000.0, 32000.0] +L = 0.03253 +rpm2thrust = [0.0, -3.6200226530383495e-7, 1.6060924304100328e-10] # Index is order +rpm2torque = [0.0, -2.2665265829562245e-9, 1.1149485566919186e-12] # Index is order +thrust2torque = 0.0069928948992470565 +mixing_matrix = [ + [-1.0, -1.0, 1.0, 1.0], + [-1.0, 1.0, 1.0, -1.0], + [-1.0, 1.0, -1.0, 1.0] +] +gravity_vec = [0.0, 0.0, -9.81] + +[cf2x_P250.state2attitude] +kp = [0.4, 0.4, 1.25] +kd = [0.2, 0.2, 0.5] +ki = [0.05, 0.05, 0.05] +int_err_max = [2.0, 2.0, 0.4] + +[cf2x_P250.attitude2force_torque] +kR = [70000.0, 70000.0, 60000.0] +kw = [20000.0, 20000.0, 12000.0] +ki_m = [0.0, 0.0, 500.0] +kd_omega = [200.0, 200.0, 0.0] +int_err_max = [1.0, 1.0, 1500.0] + +[cf2x_T350] +[cf2x_T350.core] +mass = 0.0325 # The controller is using the wrong mass by default +mass_thrust = 132000 +pwm_min = 7000 +pwm_max = 65535 +thrust_min = 0.01922636758983749 # in N per motor +thrust_max = 0.18 # in N per motor +torque_pwm_max = [32000.0, 32000.0, 32000.0] +L = 0.03253 +rpm2thrust = [0.0, -7.167227176573658e-7, 2.9401303690194613e-10] # Index is order +rpm2torque = [0.0, 5.815894847811497e-10, 1.331813874166509e-12] # Index is order +thrust2torque = 0.005355990836477486 +mixing_matrix = [ + [-1.0, -1.0, 1.0, 1.0], + [-1.0, 1.0, 1.0, -1.0], + [-1.0, 1.0, -1.0, 1.0] +] +gravity_vec = [0.0, 0.0, -9.81] + +[cf2x_T350.state2attitude] +kp = [0.4, 0.4, 1.25] +kd = [0.2, 0.2, 0.5] +ki = [0.05, 0.05, 0.05] +int_err_max = [2.0, 2.0, 0.4] + +[cf2x_T350.attitude2force_torque] +kR = [70000.0, 70000.0, 60000.0] +kw = [20000.0, 20000.0, 12000.0] +ki_m = [0.0, 0.0, 500.0] +kd_omega = [200.0, 200.0, 0.0] +int_err_max = [1.0, 1.0, 1500.0] + +[cf21B_500] +[cf21B_500.core] +gravity_vec = [0.0, 0.0, -9.81] +mass = 0.0393 # The controller is using the wrong mass by default +L = 0.035355 +rpm2thrust = [0.0, -3.133427287299859e-7, 4.407354891648379e-10] +rpm2torque = [0.0, 1.65886356219615e-9, 2.4693477924534137e-12] +thrust2torque = 0.00593893393599368 +mixing_matrix = [ + [-1.0, -1.0, 1.0, 1.0], + [-1.0, 1.0, 1.0, -1.0], + [-1.0, 1.0, -1.0, 1.0] +] +pwm_min = 7000.0 +pwm_max = 65535.0 +thrust_min = 0.02136263065537499 # in N per motor +thrust_max = 0.2 # in N per motor +mass_thrust = 132000 +torque_pwm_max = [32000.0, 32000.0, 32000.0] + +[cf21B_500.state2attitude] +kp = [0.4, 0.4, 1.25] +kd = [0.2, 0.2, 0.5] +ki = [0.05, 0.05, 0.05] +int_err_max = [2.0, 2.0, 0.4] + +[cf21B_500.attitude2force_torque] +kR = [70000.0, 70000.0, 60000.0] +kw = [20000.0, 20000.0, 12000.0] +ki_m = [0.0, 0.0, 500.0] +kd_omega = [200.0, 200.0, 0.0] +int_err_max = [1.0, 1.0, 1500.0] diff --git a/crazyflow/control/transform.py b/crazyflow/control/transform.py new file mode 100644 index 0000000..ee1484b --- /dev/null +++ b/crazyflow/control/transform.py @@ -0,0 +1,60 @@ +"""Transformations between physical parameters of the quadrotors. + +Bundles conversions between motor forces, rotor velocities, and PWM commands. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from array_api_compat import array_namespace + +if TYPE_CHECKING: + from crazyflow._typing import Array # To be changed to array_api_typing later + + +def motor_force2rotor_vel(motor_forces: Array, rpm2thrust: Array) -> Array: + """Convert motor forces to rotor velocities, where f=a*rpm^2+b*rpm+c. + + Args: + motor_forces: Motor forces in SI units with shape (..., N). + rpm2thrust: RPM to thrust conversion factors. + + Returns: + Array of rotor velocities in rad/s with shape (..., N). + """ + xp = array_namespace(motor_forces) + return ( + -rpm2thrust[1] + + xp.sqrt(rpm2thrust[1] ** 2 - 4 * rpm2thrust[2] * (rpm2thrust[0] - motor_forces)) + ) / (2 * rpm2thrust[2]) + + +def force2pwm(thrust: Array | float, thrust_max: Array | float, pwm_max: Array | float) -> Array: + """Convert thrust in N to thrust in PWM. + + Args: + thrust: Array or float of the thrust in [N] + thrust_max: Maximum thrust in [N] + pwm_max: Maximum PWM value + + Returns: + Thrust converted in PWM. + """ + return thrust / thrust_max * pwm_max + + +def pwm2force( + pwm: Array | float, thrust_max: Array | float, pwm_max: Array | float +) -> Array | float: + """Convert pwm thrust command to actual thrust. + + Args: + pwm: Array or float of the pwm value + thrust_max: Maximum thrust in [N] + pwm_max: Maximum PWM value + + Returns: + thrust: Array or float thrust in [N] + """ + return pwm / pwm_max * thrust_max diff --git a/crazyflow/drones/__init__.py b/crazyflow/drones/__init__.py new file mode 100644 index 0000000..5c686bb --- /dev/null +++ b/crazyflow/drones/__init__.py @@ -0,0 +1,37 @@ +"""Hardware descriptions for the supported drone platforms. + +This package bundles the physical assets that define each drone configuration: the MuJoCo MJCF scene +files, their referenced meshes (``assets/``), and the physical parameters shared across all dynamics +(``params.toml`` with mass, inertia, thrust and torque curves, …). These describe the *hardware* and +are independent of the dynamics formulation used to simulate it (see [crazyflow.dynamics][]). + +Use ``available_drones`` to enumerate the supported configurations, and ``load_params`` to read all +physical parameters of a drone. +""" + +import tomllib +from pathlib import Path + +# Currently supported platforms: +# * **cf2x_L250** — Crazyflie 2.x +# * **cf2x_P250** — Crazyflie 2.x with plus propellers +# * **cf2x_T350** — Crazyflie 2.x with thrust upgrade kit +# * **cf21B_500** — Crazyflie 2.1 Brushless with 500 mAh battery +available_drones: tuple[str, ...] = ("cf2x_L250", "cf2x_P250", "cf2x_T350", "cf21B_500") + +__all__ = ["available_drones", "load_params"] + + +def load_params(drone: str) -> dict: + """Load all physical parameters of a drone from ``params.toml``. + + Returns the raw values (lists/scalars) for the whole drone. + + Args: + drone: Name of the drone configuration, e.g. ``"cf2x_L250"``. + """ + with open(Path(__file__).parent / "params.toml", "rb") as f: + params = tomllib.load(f) + if drone not in params or drone not in available_drones: + raise KeyError(f"Drone `{drone}` not found in drones/params.toml") + return params[drone] diff --git a/crazyflow/drones/assets/cf21B/cf21B_PropL.stl b/crazyflow/drones/assets/cf21B/cf21B_PropL.stl new file mode 100755 index 0000000..e494937 Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf21B_PropL.stl differ diff --git a/crazyflow/drones/assets/cf21B/cf21B_PropR.stl b/crazyflow/drones/assets/cf21B/cf21B_PropR.stl new file mode 100755 index 0000000..fbb1eca Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf21B_PropR.stl differ diff --git a/crazyflow/drones/assets/cf21B/cf21B_battery-holder.stl b/crazyflow/drones/assets/cf21B/cf21B_battery-holder.stl new file mode 100755 index 0000000..c65b953 Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf21B_battery-holder.stl differ diff --git a/crazyflow/drones/assets/cf21B/cf21B_battery.stl b/crazyflow/drones/assets/cf21B/cf21B_battery.stl new file mode 100755 index 0000000..24fc6a4 Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf21B_battery.stl differ diff --git a/crazyflow/drones/assets/cf21B/cf21B_connector-pins.stl b/crazyflow/drones/assets/cf21B/cf21B_connector-pins.stl new file mode 100755 index 0000000..da6effa Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf21B_connector-pins.stl differ diff --git a/crazyflow/drones/assets/cf21B/cf21B_connectors.stl b/crazyflow/drones/assets/cf21B/cf21B_connectors.stl new file mode 100755 index 0000000..a83b448 Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf21B_connectors.stl differ diff --git a/crazyflow/drones/assets/cf21B/cf21B_full.stl b/crazyflow/drones/assets/cf21B/cf21B_full.stl new file mode 100755 index 0000000..37b43ef Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf21B_full.stl differ diff --git a/crazyflow/drones/assets/cf21B/cf21B_fused.stl b/crazyflow/drones/assets/cf21B/cf21B_fused.stl new file mode 100644 index 0000000..8ff96df Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf21B_fused.stl differ diff --git a/crazyflow/drones/assets/cf21B/cf21B_header.stl b/crazyflow/drones/assets/cf21B/cf21B_header.stl new file mode 100755 index 0000000..3657a75 Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf21B_header.stl differ diff --git a/crazyflow/drones/assets/cf21B/cf21B_motors.stl b/crazyflow/drones/assets/cf21B/cf21B_motors.stl new file mode 100755 index 0000000..a040f84 Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf21B_motors.stl differ diff --git a/crazyflow/drones/assets/cf21B/cf21B_no-prop.stl b/crazyflow/drones/assets/cf21B/cf21B_no-prop.stl new file mode 100755 index 0000000..e5efbe5 Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf21B_no-prop.stl differ diff --git a/crazyflow/drones/assets/cf21B/cf21B_pcb.stl b/crazyflow/drones/assets/cf21B/cf21B_pcb.stl new file mode 100755 index 0000000..0fb0512 Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf21B_pcb.stl differ diff --git a/crazyflow/drones/assets/cf21B/cf21B_prop-guards.stl b/crazyflow/drones/assets/cf21B/cf21B_prop-guards.stl new file mode 100755 index 0000000..d9cf5de Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf21B_prop-guards.stl differ diff --git a/crazyflow/drones/assets/cf21B/cf_led-diffuser.stl b/crazyflow/drones/assets/cf21B/cf_led-diffuser.stl new file mode 100644 index 0000000..e030a62 Binary files /dev/null and b/crazyflow/drones/assets/cf21B/cf_led-diffuser.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2xL_PropL.stl b/crazyflow/drones/assets/cf2x/cf2xL_PropL.stl new file mode 100755 index 0000000..4252e68 Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2xL_PropL.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2xL_PropR.stl b/crazyflow/drones/assets/cf2x/cf2xL_PropR.stl new file mode 100755 index 0000000..1cbf570 Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2xL_PropR.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2xL_fused.stl b/crazyflow/drones/assets/cf2x/cf2xL_fused.stl new file mode 100644 index 0000000..c26fceb Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2xL_fused.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2xL_motors.stl b/crazyflow/drones/assets/cf2x/cf2xL_motors.stl new file mode 100755 index 0000000..0979db2 Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2xL_motors.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2xP_PropL.stl b/crazyflow/drones/assets/cf2x/cf2xP_PropL.stl new file mode 100755 index 0000000..7cc9b8e Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2xP_PropL.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2xP_PropR.stl b/crazyflow/drones/assets/cf2x/cf2xP_PropR.stl new file mode 100755 index 0000000..813f0d5 Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2xP_PropR.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2xP_fused.stl b/crazyflow/drones/assets/cf2x/cf2xP_fused.stl new file mode 100644 index 0000000..ab3624f Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2xP_fused.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2xP_motors.stl b/crazyflow/drones/assets/cf2x/cf2xP_motors.stl new file mode 100755 index 0000000..0979db2 Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2xP_motors.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2xT_fused.stl b/crazyflow/drones/assets/cf2x/cf2xT_fused.stl new file mode 100644 index 0000000..48c3071 Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2xT_fused.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2xT_motors.stl b/crazyflow/drones/assets/cf2x/cf2xT_motors.stl new file mode 100755 index 0000000..430fe16 Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2xT_motors.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2x_battery-holder.stl b/crazyflow/drones/assets/cf2x/cf2x_battery-holder.stl new file mode 100755 index 0000000..6276ad5 Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2x_battery-holder.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2x_battery.stl b/crazyflow/drones/assets/cf2x/cf2x_battery.stl new file mode 100755 index 0000000..989f9f8 Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2x_battery.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2x_connector-pins.stl b/crazyflow/drones/assets/cf2x/cf2x_connector-pins.stl new file mode 100755 index 0000000..476972d Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2x_connector-pins.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2x_connectors.stl b/crazyflow/drones/assets/cf2x/cf2x_connectors.stl new file mode 100755 index 0000000..3364f70 Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2x_connectors.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2x_motor-holder.stl b/crazyflow/drones/assets/cf2x/cf2x_motor-holder.stl new file mode 100755 index 0000000..65729ab Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2x_motor-holder.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf2x_pcb.stl b/crazyflow/drones/assets/cf2x/cf2x_pcb.stl new file mode 100755 index 0000000..bfb0677 Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf2x_pcb.stl differ diff --git a/crazyflow/drones/assets/cf2x/cf_led-diffuser.stl b/crazyflow/drones/assets/cf2x/cf_led-diffuser.stl new file mode 100644 index 0000000..e030a62 Binary files /dev/null and b/crazyflow/drones/assets/cf2x/cf_led-diffuser.stl differ diff --git a/crazyflow/drones/cf21B_500.xml b/crazyflow/drones/cf21B_500.xml new file mode 100644 index 0000000..c20932f --- /dev/null +++ b/crazyflow/drones/cf21B_500.xml @@ -0,0 +1,93 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/crazyflow/drones/cf2x_L250.xml b/crazyflow/drones/cf2x_L250.xml new file mode 100644 index 0000000..47396b0 --- /dev/null +++ b/crazyflow/drones/cf2x_L250.xml @@ -0,0 +1,94 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/crazyflow/drones/cf2x_P250.xml b/crazyflow/drones/cf2x_P250.xml new file mode 100644 index 0000000..6a7fa83 --- /dev/null +++ b/crazyflow/drones/cf2x_P250.xml @@ -0,0 +1,93 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/crazyflow/drones/cf2x_T350.xml b/crazyflow/drones/cf2x_T350.xml new file mode 100644 index 0000000..1378505 --- /dev/null +++ b/crazyflow/drones/cf2x_T350.xml @@ -0,0 +1,94 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/crazyflow/drones/params.toml b/crazyflow/drones/params.toml new file mode 100644 index 0000000..a8119bd --- /dev/null +++ b/crazyflow/drones/params.toml @@ -0,0 +1,152 @@ +# This file contains all the **physical** parameters for all drones. + +[cf2x_L250] +gravity_vec = [0.0, 0.0, -9.81] +mass = 0.0319 +L = 0.03253 +J = [ # TODO + [16.8e-6, 0.0, 0.0], + [0.0, 16.8e-6, 0.0], + [0.0, 0.0, 29.8e-6] +] +rpm2thrust = [0.0, -5.382196214637237e-7, 2.4582929831265485e-10] +rpm2torque = [0.0, 1.410454111996297e-9, 1.4592584373980652e-12] +thrust2torque = 0.007350862856566459 +rotor_dyn_coef = [ 7.355623702172756, 0.0, 0.0, 0.00024443862952110715,] # [ 7.355623702172756, 0.0, 0.0, 0.00024443862952110715,] +rotor_dyn_coef_simple = 6.886705423469015 +thrust_dyn_coef = 6.8932095506763345 +mixing_matrix = [ + [-1.0, -1.0, 1.0, 1.0], + [-1.0, 1.0, 1.0, -1.0], + [-1.0, 1.0, -1.0, 1.0] +] +drag_matrix = [ # This term is from the so_rpy_rotor_drag dynamics + [-0.01471782, 0.0, 0.0 ], + [0.0, -0.01471782, 0.0 ], + [0.0, 0.0, -0.01277641 ] +] +# The following parameters are for the platform, but are maybe not actually used by the dynamics. However, +# we still keep them here in one place, since some other things (sim, estimator, firmware) might need them. +pwm_min = 7000 +pwm_max = 65535 +thrust_min = 0.012817578393224994 # in N per motor +thrust_max = 0.12 # in N per motor +vmotor2thrust = [-0.014830744918356092, 0.04724465241828281, -0.01847364358025878, 0.005960923942142] # Index is order +vmotor2torque = [-3.3261514624778425e-6, 8.109636684075977e-5, 5.751459172588052e-5, -1.898633582060136e-7] # Index is order +vmotor2rpm = [2968.1791506049194, 6647.948592402306] # Index is order +prop_radius = 23.55e-3 # TODO check +prop_inertia = 34.52e-9 # TODO seems off + + +[cf2x_P250] +gravity_vec = [0.0, 0.0, -9.81] +mass = 0.0318 +L = 0.03253 +J = [ # TODO from L250 + [16.8e-6, 0.0, 0.0], + [0.0, 16.8e-6, 0.0], + [0.0, 0.0, 29.8e-6] +] +rpm2thrust = [0.0, -3.6200226530383495e-7, 1.6060924304100328e-10] # Index is order +rpm2torque = [0.0, -2.2665265829562245e-9, 1.1149485566919186e-12] # Index is order +thrust2torque = 0.0069928948992470565 +rotor_dyn_coef = [ 5.172596691828673, 8.14774234381285e-5, 0.0, 0.0002095253491455741,] +rotor_dyn_coef_simple = 7.709730027690284 +thrust_dyn_coef = 7.9435775497736785 +mixing_matrix = [ + [-1.0, -1.0, 1.0, 1.0], + [-1.0, 1.0, 1.0, -1.0], + [-1.0, 1.0, -1.0, 1.0] +] +drag_matrix = [ # This term is from the so_rpy_rotor_drag dynamics + [-0.01351483, 0.0, 0.0 ], + [0.0, -0.01351483, 0.0 ], + [0.0, 0.0, -0.01677452 ] +] +# The following parameters are for the platform, but are maybe not actually used by the dynamics. However, +# we still keep them here in one place, since some other things (sim, estimator, firmware) might need them. +pwm_min = 7000 +pwm_max = 65535 +thrust_min = 0.012817578393224994 # in N per motor +thrust_max = 0.12 # in N per motor +vmotor2thrust = [-0.02476537915958403, 0.06523793527519485, -0.026792504967750107, 0.006776789303971145] # Index is order +vmotor2torque = [-2.8633106919309745e-5, 0.00011679386117520097, 5.105754520419129e-5, 0.0] # Index is order +vmotor2rpm = [4657.542534331524, 7536.161830990926] # Index is order +prop_radius = 23.4e-3 # TODO check +prop_inertia = 26.97e-9 + + +[cf2x_T350] +gravity_vec = [0.0, 0.0, -9.81] +mass = 0.0379 +L = 0.03253 +J = [ + [15.7e-6, 0.0, 0.0], + [0.0, 17.1e-6, 0.0], + [0.0, 0.0, 30.0e-6] +] +rpm2thrust = [0.0, -7.167227176573658e-7, 2.9401303690194613e-10] # Index is order +rpm2torque = [0.0, 5.815894847811497e-10, 1.331813874166509e-12] # Index is order +thrust2torque = 0.005355990836477486 +rotor_dyn_coef = [ 11.374753209400291, 0.0, 0.0, 0.00037867688499079635,] +rotor_dyn_coef_simple = 11.352970450445243 +thrust_dyn_coef = 11.12424272978587 +mixing_matrix = [ + [-1.0, -1.0, 1.0, 1.0], + [-1.0, 1.0, 1.0, -1.0], + [-1.0, 1.0, -1.0, 1.0] +] +drag_matrix = [ # This term is from the so_rpy_rotor_drag dynamics + [-0.01556697, 0.0, 0.0 ], + [0.0, -0.01556697, 0.0 ], + [0.0, 0.0, -0.02191672 ] +] +# The following parameters are for the platform, but are maybe not actually used by the dynamics. However, +# we still keep them here in one place, since some other things (sim, estimator, firmware) might need them. +pwm_min = 7000 +pwm_max = 65535 +thrust_min = 0.01922636758983749 # in N per motor +thrust_max = 0.18 # in N per motor +vmotor2thrust = [0.006728127583707208, 0.01011557616217668, 0.010263198062061085, 0.0028358638322392503] # Index is order +vmotor2torque = [-1.2906047901738756e-5, 0.0001436101487030899, 2.794753913624656e-5, 1.3104535533494383e-5] # Index is order +vmotor2rpm = [2977.884883031915, 8101.0293594093055] # Index is order +prop_radius = 25.4e-3 # TODO check +prop_inertia = 38.93e-9 # TODO value from B500, currenty unknown + + +[cf21B_500] +gravity_vec = [0.0, 0.0, -9.81] +mass = 0.04338 +L = 0.035355 +J = [ + [25e-6, 0.0, 0.0], + [0.0, 28e-6, 0.0], + [0.0, 0.0, 49e-6] +] +rpm2thrust = [0.0, -3.133427287299859e-7, 4.407354891648379e-10] # TODO , Index is order +rpm2torque = [0.0, 1.65886356219615e-9, 2.4693477924534137e-12] # TODO , Index is order +thrust2torque = 0.00593893393599368 +rotor_dyn_coef = [ 13.996001897562685, 0.00011093207920685363, 5.933168530682111, 0.00031951312393561264,] +rotor_dyn_coef_simple = 15.416891997523813 +thrust_dyn_coef = 15.09965949800411 +mixing_matrix = [ + [-1.0, -1.0, 1.0, 1.0], + [-1.0, 1.0, 1.0, -1.0], + [-1.0, 1.0, -1.0, 1.0] +] +drag_matrix = [ # This term is from the so_rpy_rotor_drag dynamics + [-0.02149163, 0.0, 0.0 ], + [0.0, -0.02149163, 0.0 ], + [0.0, 0.0, -0.02359736 ] +] +# The following parameters are for the platform, but are maybe not actually used by the dynamics. However, +# we still keep them here in one place, since some other things (sim, estimator, firmware) might need them. +pwm_min = 7000 +pwm_max = 65535 +thrust_min = 0.02136263065537499 # in N per motor +thrust_max = 0.2 # in N per motor +vmotor2thrust = [-0.014058926705279723, 0.04265273261724981, 0.0018327760144017432, 0.0020576974784587178] # Index is order +vmotor2torque = [-0.00016088354909542246, 0.0003960426420309137, -4.6274122414327404e-5, 1.8490661674309596e-5] # TODO, Index is order +vmotor2rpm = [2938.3995608848436, 6001.834195381014] # Index is order +prop_radius = 27.5e-3 # TODO check +prop_inertia = 38.93e-9 \ No newline at end of file diff --git a/crazyflow/dynamics/__init__.py b/crazyflow/dynamics/__init__.py new file mode 100644 index 0000000..b8426d3 --- /dev/null +++ b/crazyflow/dynamics/__init__.py @@ -0,0 +1,60 @@ +"""Quadrotor dynamics for estimation, control, and simulation. + +This package provides numeric and symbolic quadrotor dynamics at multiple fidelity levels. The +dynamics are implemented as pure functions compatible with any Array API backend (NumPy, JAX, +PyTorch, etc.) and with CasADi for symbolic computation. + +The dynamics are at the core of Crazyflow's simulation. However, they are written to be as +self-contained as possible, so that they can be used independently for other purposes, such as state +estimation or control design. + +Use [parametrize][crazyflow.dynamics.parametrize] to bind a dynamics function to a named drone +configuration, and ``available_dynamics`` to enumerate all registered dynamics. +""" + +from typing import Callable + +from crazyflow.dynamics.core import Dynamics, parametrize +from crazyflow.dynamics.first_principles import dynamics as _first_principles_dynamics +from crazyflow.dynamics.so_rpy import dynamics as _so_rpy_dynamics +from crazyflow.dynamics.so_rpy_rotor import dynamics as _so_rpy_rotor_dynamics +from crazyflow.dynamics.so_rpy_rotor_drag import dynamics as _so_rpy_rotor_drag_dynamics + +__all__ = ["parametrize", "available_dynamics", "dynamics_features", "Dynamics"] + + +available_dynamics: dict[str, Callable] = { + "first_principles": _first_principles_dynamics, + "so_rpy": _so_rpy_dynamics, + "so_rpy_rotor": _so_rpy_rotor_dynamics, + "so_rpy_rotor_drag": _so_rpy_rotor_drag_dynamics, +} + + +def dynamics_features(dynamics: Callable) -> dict[str, bool]: + """Return the feature flags declared by a dynamics function. + + Feature flags are set by the [supports][crazyflow.dynamics.core.supports] decorator on each + dynamics function and describe which optional inputs the dynamics accepts. + + Args: + dynamics: A dynamics function, or a ``functools.partial`` wrapping one (as + returned by [parametrize][crazyflow.dynamics.parametrize]). + + Returns: + A dict of feature names to booleans. Currently contains: + - ``"rotor_dynamics"``: ``True`` if the dynamics accepts and integrates + ``rotor_vel``, ``False`` if passing ``rotor_vel`` raises a + ``ValueError``. + + Example: + ```python + from crazyflow.dynamics import dynamics_features + from crazyflow.dynamics.first_principles import dynamics + + dynamics_features(dynamics) # {'rotor_dynamics': True} + ``` + """ + if hasattr(dynamics, "func"): # Is a partial function + return dynamics_features(dynamics.func) + return getattr(dynamics, "__dynamics_features__") diff --git a/crazyflow/dynamics/core.py b/crazyflow/dynamics/core.py new file mode 100644 index 0000000..31c22ae --- /dev/null +++ b/crazyflow/dynamics/core.py @@ -0,0 +1,128 @@ +"""Core tools for registering and capability checking for the drone dynamics.""" + +from __future__ import annotations + +import tomllib +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar + +import numpy as np + +from crazyflow.drones import load_params as load_physical_params +from crazyflow.utils import filter_to_signature, to_xp +from crazyflow.utils import parametrize as _parametrize + +if TYPE_CHECKING: + from types import ModuleType + +F = TypeVar("F", bound=Callable[..., Any]) +P = ParamSpec("P") +R = TypeVar("R") + + +def supports(rotor_dynamics: bool = True) -> Callable[[F], F]: + """Decorator that declares which optional inputs a dynamics function supports. + + The decorator attaches a ``__dynamics_features__`` attribute to the wrapper, which + [dynamics_features][crazyflow.dynamics.dynamics_features] reads. + + Args: + rotor_dynamics: Whether the decorated function models rotor velocity dynamics. Set to + ``False`` for models that do not accept or integrate ``rotor_vel`` (e.g. ``so_rpy``). + Defaults to ``True``. + + Returns: + The function decorated with capability flags. + """ + + def decorator(fn: F) -> F: + fn.__dynamics_features__ = {"rotor_dynamics": rotor_dynamics} + return fn + + return decorator + + +def parametrize( + fn: Callable[P, R], drone: str, xp: ModuleType | None = None, device: str | None = None +) -> Callable[P, R]: + """Parametrize a dynamics function with the default dynamics parameters for a drone. + + Args: + fn: The dynamics function to parametrize. + drone: The drone to use. + xp: The array API module to use. If not provided, numpy is used. + device: The device to use. If none, the device is inferred from the xp module. + + Example: + ```{ .python notest } + from crazyflow.dynamics.core import parametrize + from crazyflow.dynamics.first_principles import dynamics + + dynamics_fn = parametrize(dynamics, drone="cf2x_L250") + pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = dynamics_fn( + pos=pos, quat=quat, vel=vel, ang_vel=ang_vel, cmd=cmd, rotor_vel=rotor_vel + ) + ``` + + Returns: + The parametrized dynamics function with all keyword argument only parameters filled in. + """ + return _parametrize(fn, drone, load_params, xp=xp, device=device) + + +def load_params( + fn: Callable, drone: str, xp: ModuleType | None = None, device: str | None = None +) -> dict: + """Load and merge physical and dynamics-specific parameters for a drone configuration. + + Reads parameters from two TOML files: + + * ``crazyflow/drones/params.toml`` — physical parameters shared across all dynamics (mass, + inertia, thrust curves, …). + * ``crazyflow/dynamics//params.toml`` — dynamics-specific coefficients (e.g. fitted + RPY coefficients for ``so_rpy``). + + The two dicts are merged (dynamics-specific values take precedence), and ``J_inv`` is computed + from ``J`` and added to the result. + + Args: + fn: The dynamics function for which to load parameters. + drone: Name of the drone configuration, e.g. ``"cf2x_L250"``. Must exist as a section in + both TOML files. + xp: Array API module used to convert parameter values. If ``None``, NumPy is used. + device: The device to use for the arrays. If ``None``, the device is inferred from the xp + module. + + Returns: + A flat dict mapping parameter names to arrays (or scalars) in the requested array namespace. + Always contains at least ``mass``, ``J``, ``J_inv``, ``gravity_vec``, and the + dynamics-specific coefficients for ``dynamics``. + + Raises: + KeyError: If ``drone`` is not found in either TOML file, or if ``dynamics`` does not + correspond to a known sub-package. + """ + assert isinstance(fn, Callable), f"Expected a function, got {type(fn)}" + dynamics = fn.__module__.split(".")[-2] + if dynamics not in Dynamics: + raise KeyError(f"Dynamics `{dynamics}` not found. Available dynamics: {tuple(Dynamics)}") + with open(Path(__file__).parent / f"{dynamics}/params.toml", "rb") as f: + dynamics_params = tomllib.load(f) + if drone not in dynamics_params: + raise KeyError(f"Drone `{drone}` not found in {dynamics}/params.toml") + params = load_physical_params(drone) | dynamics_params[drone] + # Make sure J_inv does not have a dtype fixed before conversion to xp arrays to avoid fixing it + # to np.float64 when other frameworks might prefer a different dtype. + params["J_inv"] = np.linalg.inv(params["J"]).tolist() + return to_xp(filter_to_signature(params, fn), xp=xp, device=device) + + +class Dynamics(str, Enum): + """Dynamics mode for the simulation.""" + + first_principles = "first_principles" + so_rpy = "so_rpy" + so_rpy_rotor = "so_rpy_rotor" + so_rpy_rotor_drag = "so_rpy_rotor_drag" + default = first_principles diff --git a/crazyflow/dynamics/first_principles/__init__.py b/crazyflow/dynamics/first_principles/__init__.py new file mode 100644 index 0000000..6207606 --- /dev/null +++ b/crazyflow/dynamics/first_principles/__init__.py @@ -0,0 +1,91 @@ +r"""Full rigid-body dynamics for a quadrotor. + +This package implements Newton-Euler dynamics based on physical constants: mass, inertia, motor +thrust and torque curves, arm length, and drag coefficients. The command interface is four motor +angular velocities in RPM. No data fitting is required; all parameters are measurable physical +quantities. + +Motor forces and torques are quadratic polynomials in RPM: + +\[ + f_{p,i} = k_0 + k_1 \Omega_i + k_2 \Omega_i^2, \qquad + \tau_{p,i} = m_0 + m_1 \Omega_i + m_2 \Omega_i^2. +\] + +When rotor dynamics are modelled, each motor RPM evolves as: + +\[ + \dot{\Omega}_i = \begin{cases} + c_1 (\Omega_{\mathrm{cmd},i} - \Omega_i) + + c_2 (\Omega_{\mathrm{cmd},i}^2 - \Omega_i^2) + & \Omega_{\mathrm{cmd},i} \geq \Omega_i \\[4pt] + c_3 (\Omega_{\mathrm{cmd},i} - \Omega_i) + + c_4 (\Omega_{\mathrm{cmd},i}^2 - \Omega_i^2) + & \Omega_{\mathrm{cmd},i} < \Omega_i + \end{cases} +\] + +The rigid-body equations of motion are: + +\[ +\begin{aligned} + \dot{\mathbf{p}} &= \mathbf{v}, \\ + \dot{\mathbf{q}} &= \tfrac{1}{2} + \mathbf{q} \otimes \begin{bmatrix} {}^{\mathcal{B}}\boldsymbol{\omega}\\0 \end{bmatrix}, \\ + m\dot{\mathbf{v}} &= m\mathbf{g} + + R\,{}^{\mathcal{B}}\mathbf{f}_t + + R\,{}^{\mathcal{B}}\mathbf{f}_a, \\ + \mathbf{J}\,{}^{\mathcal{B}}\dot{\boldsymbol{\omega}} &= + {}^{\mathcal{B}}\mathbf{t}_\Sigma + - {}^{\mathcal{B}}\boldsymbol{\omega} + \times \mathbf{J}\,{}^{\mathcal{B}}\boldsymbol{\omega}, +\end{aligned} +\] + +where \(R = {}^{\mathcal{I}}R_{\mathcal{B}}(\mathbf{q})\) is the rotation from body to world frame, +and the forces and torques are: + +\[ +\begin{aligned} + {}^{\mathcal{B}}\mathbf{f}_t &= + \mathbf{e}_z \textstyle\sum_{i=1}^{4} f_{p,i}, \\ + {}^{\mathcal{B}}\mathbf{f}_a &= D_b\,R^{\top}\mathbf{v}, \\ + {}^{\mathcal{B}}\mathbf{t}_\Sigma &= + {}^{\mathcal{B}}\mathbf{t}_t + + {}^{\mathcal{B}}\mathbf{t}_d + + {}^{\mathcal{B}}\mathbf{t}_i, +\end{aligned} +\] + +with: + +\[ +\begin{aligned} + {}^{\mathcal{B}}\mathbf{t}_t &= + \frac{l}{\sqrt{2}} + \begin{bmatrix}1&0&0\\0&1&0\\0&0&0\end{bmatrix} + M\,\mathbf{f}_p, \\ + {}^{\mathcal{B}}\mathbf{t}_d &= + \begin{bmatrix}0&0&0\\0&0&0\\0&0&1\end{bmatrix} + M\,\boldsymbol{\tau}_p, \\ + {}^{\mathcal{B}}\mathbf{t}_i &= J_p + \begin{bmatrix} + -{}^{\mathcal{B}}\omega_y\;\mathbf{m}_z^{\top}\boldsymbol{\Omega} \\ + -{}^{\mathcal{B}}\omega_x\;\mathbf{m}_z^{\top}\boldsymbol{\Omega} \\ + \mathbf{m}_z^{\top}\dot{\boldsymbol{\Omega}} + \end{bmatrix}, +\end{aligned} +\] + +where \(D_b\) is the body-frame drag matrix, \(l\) is the motor arm length, \(J_p\) is the propeller +moment of inertia, \(M\) is the \(3\times 4\) mixing matrix, and \(\mathbf{m}_z\) is its last row. +""" + +from crazyflow.dynamics.first_principles.dynamics import ( + Params, + dynamics, + sim_dynamics, + symbolic_dynamics, +) + +__all__ = ["dynamics", "symbolic_dynamics", "sim_dynamics", "Params"] diff --git a/crazyflow/dynamics/first_principles/dynamics.py b/crazyflow/dynamics/first_principles/dynamics.py new file mode 100644 index 0000000..6f56ab1 --- /dev/null +++ b/crazyflow/dynamics/first_principles/dynamics.py @@ -0,0 +1,361 @@ +"""First-principles dynamics-based quadrotor dynamics. + +This module implements full rigid-body dynamics for a quadrotor based on Newton-Euler equations. The +dynamics are parameterised with physical constants (mass, inertia, thrust and torque curves, motor +arm length, drag coefficients) and require no data fitting. Propeller gyroscopic effects are +included. + +The command interface is four motor angular velocities in RPM. + +Both a numeric implementation ([dynamics][crazyflow.dynamics.first_principles.dynamics]) and a +symbolic CasADi implementation +([symbolic_dynamics][crazyflow.dynamics.first_principles.symbolic_dynamics]) are provided. +""" + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +import casadi as cs +import jax +import jax.numpy as jnp +from array_api_compat import array_namespace +from array_api_compat import device as xp_device +from flax.struct import dataclass +from scipy.spatial.transform import Rotation as R + +import crazyflow.dynamics.symbols as symbols +from crazyflow.dynamics.core import load_params, supports +from crazyflow.dynamics.utils import rotation +from crazyflow.utils import to_xp + +if TYPE_CHECKING: + from jax import Device + + from crazyflow._typing import Array # To be changed to array_api_typing later + from crazyflow.sim.data import SimData + + +@supports(rotor_dynamics=True) +def dynamics( + pos: Array, + quat: Array, + vel: Array, + ang_vel: Array, + cmd: Array, + rotor_vel: Array | None = None, + dist_f: Array | None = None, + dist_t: Array | None = None, + *, + mass: float, + L: float, + prop_inertia: float, + gravity_vec: Array, + J: Array, + J_inv: Array, + rpm2thrust: Array, + rpm2torque: Array, + mixing_matrix: Array, + drag_matrix: Array, + rotor_dyn_coef: Array, +) -> tuple[Array, Array, Array, Array, Array | None]: + r"""First principles dynamics for a quatrotor. + + The command is four motor angular velocities in RPM. Forces and torques are + computed internally using quadratic thrust and torque curves, the mixing matrix, + and the motor arm length. + + Based on the quaternion dynamics from + + Args: + pos: Position of the drone (m). + quat: Quaternion of the drone (xyzw). + vel: Velocity of the drone (m/s). + ang_vel: Angular velocity of the drone (rad/s). + cmd: Motor speeds (RPMs). + rotor_vel: Angular velocity of the 4 motors (RPMs). If None, the commanded thrust is + directly applied. If value is given, thrust dynamics are calculated. + dist_f: Disturbance force (N) in the world frame acting on the CoM. + dist_t: Disturbance torque (Nm) in the world frame acting on the CoM. + + mass: Mass of the drone (kg). + L: Distance from the CoM to the motor (m). + prop_inertia: Inertia of one propeller in z direction (kg m^2). + gravity_vec: Gravity vector (m/s^2). We assume the gravity vector points downwards, e.g. + [0, 0, -9.81]. + J: Inertia matrix (kg m^2). + J_inv: Inverse inertia matrix (1/kg m^2). + rpm2thrust: Propeller force constant (N min^2). + rpm2torque: Propeller torque constant (Nm min^2). + mixing_matrix: Mixing matrix denoting the turn direction of the motors (4x3). + drag_matrix: Drag matrix containing the linear drag coefficients (3x3). + rotor_dyn_coef: Rotor dynamics coefficients. + + + Warning: + Do not use quat_dot directly for integration! Only usage of ang_vel is mathematically + correct. If you still decide to use quat_dot to integrate, ensure unit length! + More information: + """ + xp = array_namespace(pos) + device = xp_device(pos) + mass, L, prop_inertia = to_xp(mass, L, prop_inertia, xp=xp, device=device) + gravity_vec, rpm2thrust = to_xp(gravity_vec, rpm2thrust, xp=xp, device=device) + rpm2torque, J, J_inv = to_xp(rpm2torque, J, J_inv, xp=xp, device=device) + mixing_matrix, rotor_dyn_coef = to_xp(mixing_matrix, rotor_dyn_coef, xp=xp, device=device) + drag_matrix = to_xp(drag_matrix, xp=xp, device=device) + rot = R.from_quat(quat) # from body to world + rot_mat = rot.inv().as_matrix() # from world to body + # Rotor dynamics + if rotor_vel is None: + warnings.warn("Rotor velocity not provided, using commanded rotor velocity.") + rotor_vel, rotor_vel_dot = cmd, None + else: + rotor_vel_dot = xp.where( + cmd > rotor_vel, + rotor_dyn_coef[0] * (cmd - rotor_vel) + rotor_dyn_coef[1] * (cmd**2 - rotor_vel**2), + rotor_dyn_coef[2] * (cmd - rotor_vel) + rotor_dyn_coef[3] * (cmd**2 - rotor_vel**2), + ) + # Creating force and torque vector + forces_motor = rpm2thrust[0] + rpm2thrust[1] * rotor_vel + rpm2thrust[2] * rotor_vel**2 + forces_motor_tot = xp.sum(forces_motor, axis=-1) + zeros = xp.zeros_like(forces_motor_tot) + forces_motor_vec = xp.stack((zeros, zeros, forces_motor_tot), axis=-1) + forces_motor_vec_world = rot.apply(forces_motor_vec) + force_gravity = gravity_vec * mass + force_drag = (rot_mat.mT @ (drag_matrix @ (rot_mat @ vel[..., None])))[..., 0] + + torques_motor = rpm2torque[0] + rpm2torque[1] * rotor_vel + rpm2torque[2] * rotor_vel**2 + torque_thrust = (mixing_matrix @ (forces_motor)[..., None])[..., 0] * xp.stack( + [L, L, xp.asarray(0.0)] + ) + torque_drag = (mixing_matrix @ (torques_motor)[..., None])[..., 0] * xp.stack( + [xp.asarray(0.0), xp.asarray(0.0), xp.asarray(1.0)] + ) + # convert rotor speed from RPM to rad/s for physical calculations + rpm_to_rad = 2 * xp.pi / 60 + rotor_vel_rads = rotor_vel * rpm_to_rad + rotor_vel_dot_rads = ( + rotor_vel_dot * rpm_to_rad if rotor_vel_dot is not None else xp.zeros_like(rotor_vel) + ) + torque_inertia = prop_inertia * xp.stack( + [ + -ang_vel[..., 1] * xp.sum(mixing_matrix[..., -1, :] * rotor_vel_rads, axis=-1), + -ang_vel[..., 0] * xp.sum(mixing_matrix[..., -1, :] * rotor_vel_rads, axis=-1), + xp.sum(mixing_matrix[..., -1, :] * rotor_vel_dot_rads, axis=-1), + ], + axis=-1, + ) + torque_vec = torque_thrust + torque_drag + torque_inertia + + # Linear equation of motion + forces_sum = forces_motor_vec_world + force_gravity + force_drag + if dist_f is not None: + forces_sum = forces_sum + dist_f + + pos_dot = vel + vel_dot = forces_sum / mass + + # Rotational equation of motion + if dist_t is not None: + torque_vec = torque_vec + rot.apply(dist_t, inverse=True) + quat_dot = rotation.ang_vel2quat_dot(quat, ang_vel) + torque_vec = torque_vec - xp.linalg.cross(ang_vel, (J @ ang_vel[..., None])[..., 0]) + ang_vel_dot = (J_inv @ torque_vec[..., None])[..., 0] + return pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot + + +def symbolic_dynamics( + model_rotor_vel: bool = True, + model_dist_f: bool = False, + model_dist_t: bool = False, + *, + mass: float, + L: float, + prop_inertia: float, + gravity_vec: Array, + J: Array, + J_inv: Array, + rpm2thrust: Array, + rpm2torque: Array, + mixing_matrix: Array, + rotor_dyn_coef: Array, + drag_matrix: Array, +) -> tuple[cs.MX, cs.MX, cs.MX, cs.MX]: + """Return CasADi symbolic expressions for the first-principles dynamics. + + Implements the same dynamics as [dynamics][crazyflow.dynamics.first_principles.dynamics] using + CasADi ``MX`` symbolic expressions, validated to be numerically equivalent. + + Args: + model_rotor_vel: If ``True``, the four motor RPM states are included in ``X`` and rotor + dynamics are modelled. If ``False``, the commanded RPMs are used directly. Defaults to + ``True``. + model_dist_f: If ``True``, a 3-D force disturbance is appended to ``X``. + model_dist_t: If ``True``, a 3-D torque disturbance is appended to ``X``. + mass: Drone mass in kg. + L: Distance from centre of mass to motor in metres. + prop_inertia: Moment of inertia of one propeller about its spin axis in kg m². + gravity_vec: Gravity vector, shape ``(3,)``. + J: Inertia matrix, shape ``(3, 3)``. + J_inv: Inverse inertia matrix, shape ``(3, 3)``. + rpm2thrust: Polynomial coefficients ``[a, b, c]`` for the thrust curve + ``f = a + b * rpm + c * rpm²``. + rpm2torque: Polynomial coefficients ``[a, b, c]`` for the drag-torque curve + ``τ = a + b * rpm + c * rpm²``. + mixing_matrix: Matrix of shape ``(3, 4)`` mapping per-motor forces to body torques. + rotor_dyn_coef: Four rotor dynamics coefficients ``[k_acc1, k_acc2, k_dec1, k_dec2]`` used + in the piecewise-linear spin-up/down model. + drag_matrix: Diagonal ``(3, 3)`` matrix of linear drag coefficients. + + Returns: + Tuple ``(X_dot, X, U, Y)`` of CasADi ``MX`` expressions: + + * ``X_dot``: State derivative, length 17 when ``model_rotor_vel=True`` (13 otherwise), plus + 3 per enabled disturbance. + * ``X``: State vector ``[pos(3), quat(4), vel(3), ang_vel(3)]``, with ``rotor_vel(4)`` + appended if ``model_rotor_vel=True``. + * ``U``: Input vector ``[rpm_1, rpm_2, rpm_3, rpm_4]``. + * ``Y``: Output ``[pos(3), quat(4)]``. + """ + # States and Inputs + X = cs.vertcat(symbols.pos, symbols.quat, symbols.vel, symbols.ang_vel) + if model_rotor_vel: + X = cs.vertcat(X, symbols.rotor_vel) + if model_dist_f: + X = cs.vertcat(X, symbols.dist_f) + if model_dist_t: + X = cs.vertcat(X, symbols.dist_t) + U = symbols.cmd_rotor_vel + + # Defining the dynamics function + if model_rotor_vel: + # Rotor dynamics + rotor_vel_dot = cs.if_else( + U > symbols.rotor_vel, + rotor_dyn_coef[0] * (U - symbols.rotor_vel) + + rotor_dyn_coef[1] * (U**2 - symbols.rotor_vel**2), + rotor_dyn_coef[2] * (U - symbols.rotor_vel) + + rotor_dyn_coef[3] * (U**2 - symbols.rotor_vel**2), + ) + else: + _saved_rotor_vel = symbols.rotor_vel + symbols.rotor_vel = U + # Creating force and torque vector + forces_motor = ( + rpm2thrust[0] + rpm2thrust[1] * symbols.rotor_vel + rpm2thrust[2] * symbols.rotor_vel**2 + ) + forces_motor_vec = cs.vertcat(0.0, 0.0, cs.sum1(forces_motor)) + forces_motor_vec_world = symbols.rot @ forces_motor_vec + force_gravity = gravity_vec * mass + force_drag = symbols.rot @ (drag_matrix @ (symbols.rot.T @ symbols.vel)) + + torques_motor = ( + rpm2torque[0] + rpm2torque[1] * symbols.rotor_vel + rpm2torque[2] * symbols.rotor_vel**2 + ) + torques_thrust = mixing_matrix @ forces_motor * cs.vertcat(L, L, 0.0) + torques_drag = mixing_matrix @ torques_motor * cs.vertcat(0.0, 0.0, 1.0) + # convert rotor speed from RPM to rad/s for physical calculations + rpm_to_rad = 2 * cs.pi / 60 + rotor_vel_rads = symbols.rotor_vel * rpm_to_rad + rotor_vel_dot_rads = rotor_vel_dot * rpm_to_rad if model_rotor_vel else symbols.rotor_vel * 0.0 + torque_inertia = prop_inertia * cs.vertcat( + -symbols.ang_vel[1] * cs.sum(mixing_matrix[-1, :] * rotor_vel_rads), + -symbols.ang_vel[0] * cs.sum(mixing_matrix[-1, :] * rotor_vel_rads), + cs.sum(mixing_matrix[-1, :] * rotor_vel_dot_rads), + ) + torques_motor_vec = torques_thrust + torques_drag + torque_inertia + + # Linear equation of motion + forces_sum = forces_motor_vec_world + force_gravity + force_drag + if model_dist_f: + forces_sum = forces_sum + symbols.dist_f + + pos_dot = symbols.vel + vel_dot = forces_sum / mass + + # Rotational equation of motion + xi = cs.vertcat( + cs.horzcat(0, -symbols.ang_vel.T), cs.horzcat(symbols.ang_vel, -cs.skew(symbols.ang_vel)) + ) + quat_dot = 0.5 * (xi @ symbols.quat) + torques_sum = torques_motor_vec + if model_dist_t: + torques_sum = torques_sum + symbols.rot.T @ symbols.dist_t + ang_vel_dot = J_inv @ (torques_sum - cs.cross(symbols.ang_vel, J @ symbols.ang_vel)) + + if model_rotor_vel: + X_dot = cs.vertcat(pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot) + else: + X_dot = cs.vertcat(pos_dot, quat_dot, vel_dot, ang_vel_dot) + Y = cs.vertcat(symbols.pos, symbols.quat) + + if not model_rotor_vel: + symbols.rotor_vel = _saved_rotor_vel + return X_dot, X, U, Y + + +@dataclass +class Params: + mass: Array # (N, M, 1) + """Mass of the drone.""" + L: Array # (N, M, 1) + """Arm length of the drone.""" + prop_inertia: Array # (N, M, 1) + """Inertia of the propeller.""" + gravity_vec: Array # (N, M, 3) + """Gravity vector of the drone.""" + J: Array # (N, M, 3, 3) + """Inertia matrix of the drone.""" + J_inv: Array # (N, M, 3, 3) + """Inverse of the inertia matrix of the drone.""" + rpm2thrust: Array # (N, M, 1) + """Force constant of the drone.""" + rpm2torque: Array # (N, M, 1) + """Torque constant of the drone.""" + mixing_matrix: Array # (N, M, 3, 4) + """Mixing matrix of the drone.""" + drag_matrix: Array # (N, M, 3, 3) + """Drag matrix of the drone.""" + rotor_dyn_coef: Array # (N, M, 4) + """Rotor speed dynamics time constant of the drone.""" + + @staticmethod + def create(n_worlds: int, n_drones: int, drone: str, device: Device) -> Params: + """Create a default set of parameters for the simulation.""" + p = load_params(dynamics, drone) + J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device) + return Params( + mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device), + L=jnp.asarray(p["L"], device=device), + prop_inertia=jnp.asarray(p["prop_inertia"], device=device), + gravity_vec=jnp.asarray(p["gravity_vec"], device=device), + J=J, + J_inv=jnp.linalg.inv(J), + rpm2thrust=jnp.asarray(p["rpm2thrust"], device=device), + rpm2torque=jnp.asarray(p["rpm2torque"], device=device), + mixing_matrix=jnp.asarray(p["mixing_matrix"], device=device), + drag_matrix=jnp.asarray(p["drag_matrix"], device=device), + rotor_dyn_coef=jnp.asarray(p["rotor_dyn_coef"], device=device), + ) + + +def sim_dynamics(data: SimData) -> SimData: + """Compute the forces and torques from the first principle dynamics.""" + params: Params = data.params + vel, _, acc, ang_acc, rotor_acc = dynamics( + pos=data.states.pos, + quat=data.states.quat, + vel=data.states.vel, + ang_vel=data.states.ang_vel, + cmd=data.controls.rotor_vel, + rotor_vel=data.states.rotor_vel, + dist_f=data.states.force, + dist_t=data.states.torque, + **params.__dict__, + ) + states_deriv = data.states_deriv.replace( + vel=vel, ang_vel=data.states.ang_vel, acc=acc, ang_acc=ang_acc, rotor_acc=rotor_acc + ) + return data.replace(states_deriv=states_deriv) diff --git a/crazyflow/dynamics/first_principles/params.toml b/crazyflow/dynamics/first_principles/params.toml new file mode 100644 index 0000000..f699bad --- /dev/null +++ b/crazyflow/dynamics/first_principles/params.toml @@ -0,0 +1,13 @@ +# Since the first principles dynamics only rely on physical parameters, +# which are already defined in data/params.toml, this file is intentionally left empty. + +[cf2x_L250] + + +[cf2x_P250] + + +[cf2x_T350] + + +[cf21B_500] \ No newline at end of file diff --git a/crazyflow/dynamics/so_rpy/__init__.py b/crazyflow/dynamics/so_rpy/__init__.py new file mode 100644 index 0000000..c71d432 --- /dev/null +++ b/crazyflow/dynamics/so_rpy/__init__.py @@ -0,0 +1,43 @@ +r"""Second-order fitted RPY dynamics (no rotor dynamics). + +Rotational dynamics are modelled as a fitted second-order linear system driven by roll, pitch, and +yaw commands. Translational dynamics are driven by the collective thrust command directly, with no +motor spin-up lag. The command interface is ``[roll_rad, pitch_rad, yaw_rad, thrust_N]``. + +\[ +\begin{aligned} + \dot{\mathbf{p}} &= \mathbf{v}, \\ + m\dot{\mathbf{v}} &= m\mathbf{g} + + (c_{\mathrm{acc}} + c_f F_{\mathrm{cmd}})\,R\,\mathbf{e}_z, \\ + \ddot{\boldsymbol{\psi}} &= + c_{\psi}\,\boldsymbol{\psi} + + c_{\dot{\psi}}\,\dot{\boldsymbol{\psi}} + + c_u\,\mathbf{u}_{\mathrm{rpy}}, +\end{aligned} +\] + +The vector \(\boldsymbol{\psi} = [\phi,\theta,\psi]^{\top}\) holds the roll, pitch, and yaw angles +with rates \(\dot{\boldsymbol{\psi}}\). The coefficients \(c_{\psi}\), \(c_{\dot{\psi}}\), and +\(c_u\) are identified from flight data. + +!!! note + This is the native Euler-angle form, matching + [symbolic_dynamics_euler][crazyflow.dynamics.so_rpy.symbolic_dynamics_euler]. The simulation + does not integrate this state directly. It shares the common ``[pos, quat, vel, ang_vel]`` state + with the other models and advances the orientation from the body angular velocity + \({}^{\mathcal{B}}\boldsymbol{\omega}\), converting \(\ddot{\boldsymbol{\psi}} \leftrightarrow + {}^{\mathcal{B}}\dot{\boldsymbol{\omega}}\) through the kinematic Jacobian at every step. + Integrating from \({}^{\mathcal{B}}\boldsymbol{\omega}\) rather than \(\dot{\boldsymbol{\psi}}\) + makes the discrete trajectory differ slightly from integrating the Euler state directly. The + difference, however, is negligible at our default frequency of 500 Hz. +""" + +from crazyflow.dynamics.so_rpy.dynamics import ( + Params, + dynamics, + sim_dynamics, + symbolic_dynamics, + symbolic_dynamics_euler, +) + +__all__ = ["Params", "dynamics", "sim_dynamics", "symbolic_dynamics", "symbolic_dynamics_euler"] diff --git a/crazyflow/dynamics/so_rpy/dynamics.py b/crazyflow/dynamics/so_rpy/dynamics.py new file mode 100644 index 0000000..a5bffb1 --- /dev/null +++ b/crazyflow/dynamics/so_rpy/dynamics.py @@ -0,0 +1,380 @@ +"""Second-order fitted RPY dynamics (no rotor dynamics). + +This module implements a simplified quadrotor dynamics where the rotational dynamics are modelled as +a fitted second-order linear system driven by roll, pitch, and yaw (RPY) commands, and the +translational dynamics are driven by the collective thrust command. Motor spin-up dynamics are not +modelled. + +The command interface is ``[roll_rad, pitch_rad, yaw_rad, thrust_N]``. + +Both a numeric implementation ([dynamics][crazyflow.dynamics.so_rpy.dynamics]) and symbolic CasADi +implementations ([symbolic_dynamics][crazyflow.dynamics.so_rpy.symbolic_dynamics], +[symbolic_dynamics_euler][crazyflow.dynamics.so_rpy.symbolic_dynamics_euler]) are provided. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import casadi as cs +import jax +import jax.numpy as jnp +from array_api_compat import array_namespace +from array_api_compat import device as xp_device +from flax.struct import dataclass +from scipy.spatial.transform import Rotation as R + +import crazyflow.dynamics.symbols as symbols +from crazyflow.dynamics.core import load_params, supports +from crazyflow.dynamics.utils import rotation +from crazyflow.utils import to_xp + +if TYPE_CHECKING: + from jax import Device + + from crazyflow._typing import Array # To be changed to array_api_typing later + from crazyflow.sim.data import SimData + + +@supports(rotor_dynamics=False) +def dynamics( + pos: Array, + quat: Array, + vel: Array, + ang_vel: Array, + cmd: Array, + dist_f: Array | None = None, + dist_t: Array | None = None, + *, + mass: float, + gravity_vec: Array, + J: Array, + J_inv: Array, + acc_coef: Array, + cmd_f_coef: Array, + rpy_coef: Array, + rpy_rates_coef: Array, + cmd_rpy_coef: Array, +) -> tuple[Array, Array, Array, Array]: + """The fitted linear, second order rpy dynamics. + + Converts the state to Euler angles, evaluates ``dynamics_euler``, maps the derivatives back, and + adds the force/torque disturbances. + + Args: + pos: Position of the drone (m). + quat: Quaternion of the drone (xyzw). + vel: Velocity of the drone (m/s). + ang_vel: Angular velocity of the drone (rad/s). + cmd: Roll pitch yaw (rad) and collective thrust (N) command. + dist_f: Disturbance force (N) in the world frame acting on the CoM. + dist_t: Disturbance torque (Nm) in the world frame acting on the CoM. + + mass: Mass of the drone (kg). + gravity_vec: Gravity vector (m/s^2). We assume the gravity vector points downwards, e.g. + [0, 0, -9.81]. + J: Inertia matrix (kg m^2). + J_inv: Inverse inertia matrix (1/kg m^2). + acc_coef: Coefficient for the acceleration (1/s^2). + cmd_f_coef: Coefficient for the collective thrust (N/rad^2). + rpy_coef: Coefficient for the roll pitch yaw dynamics (1/s). + rpy_rates_coef: Coefficient for the roll pitch yaw rates dynamics (1/s^2). + cmd_rpy_coef: Coefficient for the roll pitch yaw command dynamics (1/s). + + Returns: + The derivatives (pos_dot, quat_dot, vel_dot, ang_vel_dot). + """ + xp = array_namespace(pos) + # Convert parameters to correct xp framework + device = xp_device(pos) + mass, J, J_inv = to_xp(mass, J, J_inv, xp=xp, device=device) + # Convert to the native Euler-angle state, evaluate the core dynamics, then map back + rot = R.from_quat(quat) + rpy = rot.as_euler("xyz") + rpy_rates = rotation.ang_vel2rpy_rates(quat, ang_vel) + pos_dot, _, vel_dot, rpy_rates_dot = dynamics_euler( + pos, + rpy, + vel, + rpy_rates, + cmd, + mass=mass, + gravity_vec=gravity_vec, + acc_coef=acc_coef, + cmd_f_coef=cmd_f_coef, + rpy_coef=rpy_coef, + rpy_rates_coef=rpy_rates_coef, + cmd_rpy_coef=cmd_rpy_coef, + ) + + if dist_f is not None: + vel_dot = vel_dot + dist_f / mass # Adding force disturbances to the state + quat_dot = rotation.ang_vel2quat_dot(quat, ang_vel) + ang_vel_dot = rotation.rpy_rates_deriv2ang_vel_deriv(quat, rpy_rates, rpy_rates_dot) + if dist_t is not None: + # adding torque disturbances to the state + # angular acceleration can be converted to total torque given the inertia matrix + torque = (J @ ang_vel_dot[..., None])[..., 0] + torque = torque + xp.linalg.cross(ang_vel, (J @ ang_vel[..., None])[..., 0]) + # adding torque. TODO: This should be a linear transformation. Can't we just transform the + # disturbance torque to an ang_vel_dot summand directly? + torque = torque + rot.apply(dist_t, inverse=True) + # back to angular acceleration + torque = torque - xp.linalg.cross(ang_vel, (J @ ang_vel[..., None])[..., 0]) + ang_vel_dot = (J_inv @ torque[..., None])[..., 0] + + return pos_dot, quat_dot, vel_dot, ang_vel_dot + + +def dynamics_euler( + pos: Array, + rpy: Array, + vel: Array, + rpy_rates: Array, + cmd: Array, + *, + mass: float, + gravity_vec: Array, + acc_coef: Array, + cmd_f_coef: Array, + rpy_coef: Array, + rpy_rates_coef: Array, + cmd_rpy_coef: Array, +) -> tuple[Array, Array, Array, Array]: + """Core fitted second-order rpy dynamics in Euler-angle coordinates.""" + xp = array_namespace(pos) + device = xp_device(pos) + mass, gravity_vec = to_xp(mass, gravity_vec, xp=xp, device=device) + acc_coef, cmd_f_coef, rpy_coef = to_xp(acc_coef, cmd_f_coef, rpy_coef, xp=xp, device=device) + rpy_rates_coef, cmd_rpy_coef = to_xp(rpy_rates_coef, cmd_rpy_coef, xp=xp, device=device) + cmd_f = cmd[..., -1] + cmd_rpy = cmd[..., 0:3] + drone_z_axis = R.from_euler("xyz", rpy).as_matrix()[..., -1] + thrust = acc_coef + cmd_f_coef * cmd_f + pos_dot = vel + vel_dot = 1.0 / mass * thrust[..., None] * drone_z_axis + gravity_vec + rpy_rates_dot = rpy_coef * rpy + rpy_rates_coef * rpy_rates + cmd_rpy_coef * cmd_rpy + return pos_dot, rpy_rates, vel_dot, rpy_rates_dot + + +def symbolic_dynamics( + model_dist_f: bool = False, + model_dist_t: bool = False, + *, + mass: float, + gravity_vec: Array, + J: Array, + J_inv: Array, + acc_coef: Array, + cmd_f_coef: Array, + rpy_coef: Array, + rpy_rates_coef: Array, + cmd_rpy_coef: Array, +) -> tuple[cs.MX, cs.MX, cs.MX, cs.MX]: + """Return CasADi symbolic expressions for the so_rpy dynamics in quaternion form. + + Internally delegates to + [symbolic_dynamics_euler][crazyflow.dynamics.so_rpy.symbolic_dynamics_euler] and converts the + Euler-angle state to quaternion + angular-velocity state so that the interface matches that of + [symbolic_dynamics][crazyflow.dynamics.first_principles.symbolic_dynamics]. + + Args: + model_dist_f: If ``True``, a 3-D force disturbance is appended to ``X``. + model_dist_t: If ``True``, a 3-D torque disturbance is appended to ``X``. + mass: Drone mass in kg. + gravity_vec: Gravity vector, shape ``(3,)``. + J: Inertia matrix, shape ``(3, 3)``. + J_inv: Inverse inertia matrix, shape ``(3, 3)``. + acc_coef: Scalar acceleration offset coefficient. + cmd_f_coef: Collective-thrust-to-acceleration coefficient. + rpy_coef: RPY state feedback coefficient, shape ``(3,)``. + rpy_rates_coef: RPY-rate feedback coefficient, shape ``(3,)``. + cmd_rpy_coef: RPY command feedforward coefficient, shape ``(3,)``. + + Returns: + Tuple ``(X_dot, X, U, Y)`` of CasADi ``MX`` expressions: + + * ``X_dot``: State derivative, length 13 (or more with disturbance states). + * ``X``: State vector ``[pos(3), quat(4), vel(3), ang_vel(3)]``. + * ``U``: Input vector ``[roll_rad, pitch_rad, yaw_rad, thrust_N]``. + * ``Y``: Output ``[pos(3), quat(4)]``. + """ + # We need to set the rpy and drpy symbols before building the euler dynamics + _saved_rpy = symbols.rpy + _saved_drpy = symbols.drpy + _rpy_quat = rotation.cs_quat2euler(symbols.quat) + _drpy_quat = rotation.cs_ang_vel2rpy_rates(symbols.quat, symbols.ang_vel) + symbols.rpy = _rpy_quat + symbols.drpy = _drpy_quat + X_dot_euler, X_euler, U_euler, Y_euler = symbolic_dynamics_euler( + mass=mass, + gravity_vec=gravity_vec, + acc_coef=acc_coef, + cmd_f_coef=cmd_f_coef, + rpy_coef=rpy_coef, + rpy_rates_coef=rpy_rates_coef, + cmd_rpy_coef=cmd_rpy_coef, + ) + symbols.rpy = _saved_rpy + symbols.drpy = _saved_drpy + + # States and Inputs + X = cs.vertcat(symbols.pos, symbols.quat, symbols.vel, symbols.ang_vel) + if model_dist_f: + X = cs.vertcat(X, symbols.dist_f) + if model_dist_t: + X = cs.vertcat(X, symbols.dist_t) + U = U_euler + + # Linear equation of motion + pos_dot = X_dot_euler[0:3] + vel_dot = X_dot_euler[6:9] + if model_dist_f: + # Adding force disturbances to the state + vel_dot = vel_dot + symbols.dist_f / mass + + # Rotational equation of motion + xi = cs.vertcat( + cs.horzcat(0, -symbols.ang_vel.T), cs.horzcat(symbols.ang_vel, -cs.skew(symbols.ang_vel)) + ) + quat_dot = 0.5 * (xi @ symbols.quat) + ang_vel_dot = rotation.cs_rpy_rates_deriv2ang_vel_deriv( + symbols.quat, _drpy_quat, X_dot_euler[9:12] + ) + if model_dist_t: + # adding torque disturbances to the state + # angular acceleration can be converted to total torque + torque = J @ ang_vel_dot + cs.cross(symbols.ang_vel, J @ symbols.ang_vel) + # adding torque + torque = torque + symbols.rot.T @ symbols.dist_t + # back to angular acceleration + ang_vel_dot = J_inv @ (torque - cs.cross(symbols.ang_vel, J @ symbols.ang_vel)) + + X_dot = cs.vertcat(pos_dot, quat_dot, vel_dot, ang_vel_dot) + Y = cs.vertcat(symbols.pos, symbols.quat) + + return X_dot, X, U, Y + + +def symbolic_dynamics_euler( + *, + mass: float, + gravity_vec: Array, + acc_coef: Array, + cmd_f_coef: Array, + rpy_coef: Array, + rpy_rates_coef: Array, + cmd_rpy_coef: Array, +) -> tuple[cs.MX, cs.MX, cs.MX, cs.MX]: + """Return CasADi symbolic expressions for the so_rpy dynamics in Euler-angle form. + + This is the native representation of the ``so_rpy`` dynamics. The state uses roll/pitch/yaw and + their rates rather than quaternion + angular velocity, which avoids trigonometric overhead + inside CasADi-based solvers. + + Args: + mass: Drone mass in kg. + gravity_vec: Gravity vector, shape ``(3,)``. + acc_coef: Scalar acceleration offset coefficient. + cmd_f_coef: Collective-thrust-to-acceleration coefficient. + rpy_coef: RPY state feedback coefficient, shape ``(3,)``. + rpy_rates_coef: RPY-rate feedback coefficient, shape ``(3,)``. + cmd_rpy_coef: RPY command feedforward coefficient, shape ``(3,)``. + + Returns: + Tuple ``(X_dot, X, U, Y)`` of CasADi ``MX`` expressions: + + * ``X_dot``: State derivative, length 12. + * ``X``: State vector ``[pos(3), rpy(3), vel(3), drpy(3)]``. + * ``U``: Input vector ``[roll_rad, pitch_rad, yaw_rad, thrust_N]``. + * ``Y``: Output ``[pos(3), rpy(3)]``. + """ + # States and Inputs + X = cs.vertcat(symbols.pos, symbols.rpy, symbols.vel, symbols.drpy) + U = symbols.cmd_rpyt + cmd_rpy = U[:3] + cmd_thrust = U[-1] + rot = rotation.cs_rpy2matrix(symbols.rpy) + + # Defining the dynamics function + forces_motor = cmd_thrust + + # Creating force vector + forces_motor_vec = cs.vertcat(0, 0, acc_coef + cmd_f_coef * forces_motor) + + # Linear equation of motion + pos_dot = symbols.vel + vel_dot = rot @ forces_motor_vec / mass + gravity_vec + + ddrpy = rpy_coef * symbols.rpy + rpy_rates_coef * symbols.drpy + cmd_rpy_coef * cmd_rpy + + X_dot = cs.vertcat(pos_dot, symbols.drpy, vel_dot, ddrpy) + Y = cs.vertcat(symbols.pos, symbols.rpy) + + return X_dot, X, U, Y + + +@dataclass +class Params: + mass: Array # (N, M, 1) + """Mass of the drone.""" + + gravity_vec: Array # (N, M, 3) + """Gravity vector of the drone.""" + + J: Array # (N, M, 3, 3) + """Inertia matrix of the drone.""" + + J_inv: Array # (N, M, 3, 3) + """Inverse of the inertia matrix of the drone.""" + + acc_coef: Array # (N, M, 1) + """Coefficient for the acceleration.""" + + cmd_f_coef: Array # (N, M, 1) + """Coefficient for the collective thrust.""" + + rpy_coef: Array # (N, M, 1) + """Coefficient for the roll pitch yaw dynamics.""" + + rpy_rates_coef: Array # (N, M, 1) + """Coefficient for the roll pitch yaw rates dynamics.""" + + cmd_rpy_coef: Array # (N, M, 1) + """Coefficient for the roll pitch yaw command dynamics.""" + + @staticmethod + def create(n_worlds: int, n_drones: int, drone: str, device: Device) -> Params: + """Create a default set of parameters for the simulation.""" + p = load_params(dynamics, drone) + J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device) + return Params( + mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device), + gravity_vec=jnp.asarray(p["gravity_vec"], device=device), + J=J, + J_inv=jnp.linalg.inv(J), + acc_coef=jnp.asarray(p["acc_coef"], device=device), + cmd_f_coef=jnp.asarray(p["cmd_f_coef"], device=device), + rpy_coef=jnp.asarray(p["rpy_coef"], device=device), + rpy_rates_coef=jnp.asarray(p["rpy_rates_coef"], device=device), + cmd_rpy_coef=jnp.asarray(p["cmd_rpy_coef"], device=device), + ) + + +def sim_dynamics(data: SimData) -> SimData: + """Compute the forces and torques from the so_rpy dynamics.""" + params: Params = data.params + vel, _, acc, ang_acc = dynamics( + pos=data.states.pos, + quat=data.states.quat, + vel=data.states.vel, + ang_vel=data.states.ang_vel, + cmd=data.controls.attitude.cmd, + dist_f=data.states.force, + dist_t=data.states.torque, + **params.__dict__, + ) + states_deriv = data.states_deriv.replace( + vel=vel, ang_vel=data.states.ang_vel, acc=acc, ang_acc=ang_acc + ) + return data.replace(states_deriv=states_deriv) diff --git a/crazyflow/dynamics/so_rpy/params.toml b/crazyflow/dynamics/so_rpy/params.toml new file mode 100644 index 0000000..2de1657 --- /dev/null +++ b/crazyflow/dynamics/so_rpy/params.toml @@ -0,0 +1,30 @@ +[cf2x_L250] +acc_coef = 0.0 +cmd_f_coef = 0.97605781 +rpy_coef = [-245.67, -245.67, -227.78] +rpy_rates_coef = [-17.32, -17.32, -25.63] +cmd_rpy_coef = [196.18, 196.18, 390.27] + + +[cf2x_P250] +acc_coef = 0.0 +cmd_f_coef = 0.98275823 +rpy_coef = [-319.14, -319.14, -284.28] +rpy_rates_coef = [-20.85, -20.85, -38.43] +cmd_rpy_coef = [263.30, 263.30, 502.58] + + +[cf2x_T350] +acc_coef = 0.0 +cmd_f_coef = 1.0089779349974615 +rpy_coef = [-371.41695523, -371.41695523, -261.99549945] +rpy_rates_coef = [-29.26311118, -29.26311118, -29.74357219] +cmd_rpy_coef = [347.94260321, 347.94260321, 241.06977014] + + +[cf21B_500] +acc_coef = 0.0 +cmd_f_coef = 0.96836458 +rpy_coef = [-188.9910, -188.9910, -138.3109] +rpy_rates_coef = [-12.7803, -12.7803, -16.8485] +cmd_rpy_coef = [138.0834, 138.0834, 198.5161] diff --git a/crazyflow/dynamics/so_rpy_rotor/__init__.py b/crazyflow/dynamics/so_rpy_rotor/__init__.py new file mode 100644 index 0000000..1853dc2 --- /dev/null +++ b/crazyflow/dynamics/so_rpy_rotor/__init__.py @@ -0,0 +1,37 @@ +r"""Second-order fitted RPY dynamics with first-order thrust dynamics. + +Extends ``so_rpy`` by adding a scalar thrust state \(F\) that captures motor spin-up and spin-down +with a first-order lag. Rotational dynamics remain a fitted second-order linear system driven by RPY +commands. The command interface is ``[roll_rad, pitch_rad, yaw_rad, thrust_N]``. The ``rotor_vel`` +state is the current thrust in Newtons (not motor RPMs). + +\[ +\begin{aligned} + \dot{F} &= \frac{1}{\tau}(F_{\mathrm{cmd}} - F), \\ + \dot{\mathbf{p}} &= \mathbf{v}, \\ + m\dot{\mathbf{v}} &= m\mathbf{g} + + (c_{\mathrm{acc}} + c_f F)\,R\,\mathbf{e}_z, \\ + \ddot{\boldsymbol{\psi}} &= + c_{\psi}\,\boldsymbol{\psi} + + c_{\dot{\psi}}\,\dot{\boldsymbol{\psi}} + + c_u\,\mathbf{u}_{\mathrm{rpy}}, +\end{aligned} +\] + +where \(\tau\) is the thrust time constant, \(\boldsymbol{\psi} = [\phi,\theta,\psi]^{\top}\) are +the roll/pitch/yaw angles with rates \(\dot{\boldsymbol{\psi}}\), and +\(R = {}^{\mathcal{I}}R_{\mathcal{B}}(\boldsymbol{\psi})\) is the rotation from body to world frame. + +This is the native Euler-angle form. For how the simulation integrates this state in quaternion + +angular velocity coordinates, see [so_rpy][crazyflow.dynamics.so_rpy]. +""" + +from crazyflow.dynamics.so_rpy_rotor.dynamics import ( + Params, + dynamics, + sim_dynamics, + symbolic_dynamics, + symbolic_dynamics_euler, +) + +__all__ = ["Params", "dynamics", "sim_dynamics", "symbolic_dynamics", "symbolic_dynamics_euler"] diff --git a/crazyflow/dynamics/so_rpy_rotor/dynamics.py b/crazyflow/dynamics/so_rpy_rotor/dynamics.py new file mode 100644 index 0000000..6abc700 --- /dev/null +++ b/crazyflow/dynamics/so_rpy_rotor/dynamics.py @@ -0,0 +1,433 @@ +"""Second-order fitted RPY dynamics with first-order thrust dynamics. + +This module extends the ``so_rpy`` dynamics by adding a scalar thrust state that models motor +spin-up and spin-down with a first-order lag. Rotational dynamics are still modelled as a fitted +second-order linear system driven by RPY commands. + +The command interface is ``[roll_rad, pitch_rad, yaw_rad, thrust_N]``. The ``rotor_vel`` state is a +**scalar thrust state in Newtons** (not motor RPMs). + +Both a numeric implementation ([dynamics][crazyflow.dynamics.so_rpy_rotor.dynamics]) and symbolic +CasADi implementations ([symbolic_dynamics][crazyflow.dynamics.so_rpy_rotor.symbolic_dynamics], +[symbolic_dynamics_euler][crazyflow.dynamics.so_rpy_rotor.symbolic_dynamics_euler]) are provided. +""" + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +import casadi as cs +import jax +import jax.numpy as jnp +from array_api_compat import array_namespace +from array_api_compat import device as xp_device +from flax.struct import dataclass +from scipy.spatial.transform import Rotation as R + +import crazyflow.dynamics.symbols as symbols +from crazyflow.dynamics.core import load_params, supports +from crazyflow.dynamics.utils import rotation +from crazyflow.utils import to_xp + +if TYPE_CHECKING: + from jax import Device + + from crazyflow._typing import Array # To be changed to array_api_typing later + from crazyflow.sim.data import SimData + + +@supports(rotor_dynamics=True) +def dynamics( + pos: Array, + quat: Array, + vel: Array, + ang_vel: Array, + cmd: Array, + rotor_vel: Array | None = None, + dist_f: Array | None = None, + dist_t: Array | None = None, + *, + mass: float, + gravity_vec: Array, + J: Array, + J_inv: Array, + thrust_time_coef: Array, + acc_coef: Array, + cmd_f_coef: Array, + rpy_coef: Array, + rpy_rates_coef: Array, + cmd_rpy_coef: Array, +) -> tuple[Array, Array, Array, Array, Array | None]: + """Fitted linear, second order rpy dynamics with thrust dynamics. + + Converts the state to Euler angles, evaluates ``dynamics_euler``, maps the derivatives back, and + adds the force/torque disturbances. + + Args: + pos: Position of the drone (m). + quat: Quaternion of the drone (xyzw). + vel: Velocity of the drone (m/s). + ang_vel: Angular velocity of the drone (rad/s). + cmd: Roll pitch yaw (rad) and collective thrust (N) command. + rotor_vel: Speed of the 4 motors (RPMs). If None, the commanded thrust is directly + applied (not recommended). If value is given, rotor dynamics are calculated. + dist_f: Disturbance force (N) in the world frame acting on the CoM. + dist_t: Disturbance torque (Nm) in the world frame acting on the CoM. + + mass: Mass of the drone (kg). + gravity_vec: Gravity vector (m/s^2). We assume the gravity vector points downwards, e.g. + [0, 0, -9.81]. + J: Inertia matrix (kg m^2). + J_inv: Inverse inertia matrix (1/kg m^2). + thrust_time_coef: Coefficient for the rotor dynamics (1/s). + acc_coef: Coefficient for the acceleration (1/s^2). + cmd_f_coef: Coefficient for the collective thrust (N/rad^2). + rpy_coef: Coefficient for the roll pitch yaw dynamics (1/s). + rpy_rates_coef: Coefficient for the roll pitch yaw rates dynamics (1/s^2). + cmd_rpy_coef: Coefficient for the roll pitch yaw command dynamics (1/s). + + Returns: + The derivatives (pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot). + """ + xp = array_namespace(pos) + # Convert parameters to correct xp framework + device = xp_device(pos) + mass, J, J_inv = to_xp(mass, J, J_inv, xp=xp, device=device) + + # Convert to the native Euler-angle state, evaluate the core dynamics, then map back + rot = R.from_quat(quat) + rpy = rot.as_euler("xyz") + rpy_rates = rotation.ang_vel2rpy_rates(quat, ang_vel) + pos_dot, _, vel_dot, rpy_rates_dot, rotor_vel_dot = dynamics_euler( + pos, + rpy, + vel, + rpy_rates, + cmd, + rotor_vel, + mass=mass, + gravity_vec=gravity_vec, + thrust_time_coef=thrust_time_coef, + acc_coef=acc_coef, + cmd_f_coef=cmd_f_coef, + rpy_coef=rpy_coef, + rpy_rates_coef=rpy_rates_coef, + cmd_rpy_coef=cmd_rpy_coef, + ) + + if dist_f is not None: + vel_dot = vel_dot + dist_f / mass # Adding force disturbances to the state + quat_dot = rotation.ang_vel2quat_dot(quat, ang_vel) + ang_vel_dot = rotation.rpy_rates_deriv2ang_vel_deriv(quat, rpy_rates, rpy_rates_dot) + if dist_t is not None: + # adding torque disturbances to the state + # angular acceleration can be converted to total torque given the inertia matrix + torque = (J @ ang_vel_dot[..., None])[..., 0] + torque = torque + xp.linalg.cross(ang_vel, (J @ ang_vel[..., None])[..., 0]) + # adding torque. TODO: This should be a linear transformation. Can't we just transform the + # disturbance torque to an ang_vel_dot summand directly? + torque = torque + rot.apply(dist_t, inverse=True) + # back to angular acceleration + torque = torque - xp.linalg.cross(ang_vel, (J @ ang_vel[..., None])[..., 0]) + ang_vel_dot = (J_inv @ torque[..., None])[..., 0] + + return pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot + + +def dynamics_euler( + pos: Array, + rpy: Array, + vel: Array, + rpy_rates: Array, + cmd: Array, + rotor_vel: Array | None = None, + *, + mass: float, + gravity_vec: Array, + thrust_time_coef: Array, + acc_coef: Array, + cmd_f_coef: Array, + rpy_coef: Array, + rpy_rates_coef: Array, + cmd_rpy_coef: Array, +) -> tuple[Array, Array, Array, Array, Array | None]: + """Core fitted second-order rpy and thrust dynamics in Euler-angle coordinates.""" + xp = array_namespace(pos) + device = xp_device(pos) + mass, gravity_vec = to_xp(mass, gravity_vec, xp=xp, device=device) + thrust_time_coef, acc_coef = to_xp(thrust_time_coef, acc_coef, xp=xp, device=device) + cmd_f_coef, rpy_coef = to_xp(cmd_f_coef, rpy_coef, xp=xp, device=device) + rpy_rates_coef, cmd_rpy_coef = to_xp(rpy_rates_coef, cmd_rpy_coef, xp=xp, device=device) + cmd_f = cmd[..., -1] + cmd_rpy = cmd[..., 0:3] + # Note that we are abusing the rotor_vel state as the thrust + if rotor_vel is None: + warnings.warn("Rotor velocity not provided, using commanded rotor velocity.") + rotor_vel, rotor_vel_dot = cmd_f[..., None], None + else: + rotor_vel_dot = 1 / thrust_time_coef * (cmd_f[..., None] - rotor_vel) + forces_motor = rotor_vel[..., 0] + thrust = acc_coef + cmd_f_coef * forces_motor + drone_z_axis = R.from_euler("xyz", rpy).as_matrix()[..., -1] + pos_dot = vel + vel_dot = 1.0 / mass * thrust[..., None] * drone_z_axis + gravity_vec + rpy_rates_dot = rpy_coef * rpy + rpy_rates_coef * rpy_rates + cmd_rpy_coef * cmd_rpy + return pos_dot, rpy_rates, vel_dot, rpy_rates_dot, rotor_vel_dot + + +def symbolic_dynamics( + model_rotor_vel: bool = True, + model_dist_f: bool = False, + model_dist_t: bool = False, + *, + mass: float, + gravity_vec: Array, + J: Array, + J_inv: Array, + thrust_time_coef: Array, + acc_coef: Array, + cmd_f_coef: Array, + rpy_coef: Array, + rpy_rates_coef: Array, + cmd_rpy_coef: Array, +) -> tuple[cs.MX, cs.MX, cs.MX, cs.MX]: + """Return CasADi symbolic expressions for the so_rpy_rotor dynamics in quaternion form. + + Internally delegates to + [symbolic_dynamics_euler][crazyflow.dynamics.so_rpy_rotor.symbolic_dynamics_euler] and converts + the Euler-angle state to quaternion + angular-velocity state so that the interface matches that + of [symbolic_dynamics][crazyflow.dynamics.first_principles.symbolic_dynamics]. + + Args: + model_rotor_vel: If ``True``, the scalar thrust state is included in ``X`` and first-order + thrust dynamics are modelled. Defaults to ``_True``. + model_dist_f: If ``True``, a 3-D force disturbance is appended to ``X``. + model_dist_t: If ``True``, a 3-D torque disturbance is appended to ``X``. + mass: Drone mass in kg. + gravity_vec: Gravity vector, shape ``(3,)``. + J: Inertia matrix, shape ``(3, 3)``. + J_inv: Inverse inertia matrix, shape ``(3, 3)``. + thrust_time_coef: First-order thrust lag time constant coefficient (1/s). + acc_coef: Scalar acceleration offset coefficient. + cmd_f_coef: Collective-thrust-to-acceleration coefficient. + rpy_coef: RPY state feedback coefficient, shape ``(3,)``. + rpy_rates_coef: RPY-rate feedback coefficient, shape ``(3,)``. + cmd_rpy_coef: RPY command feedforward coefficient, shape ``(3,)``. + + Returns: + Tuple ``(X_dot, X, U, Y)`` of CasADi ``MX`` expressions: + + * ``X_dot``: State derivative, length 14 when ``model_rotor_vel=True`` (13 otherwise), plus + 3 per enabled disturbance. + * ``X``: State vector ``[pos(3), quat(4), vel(3), ang_vel(3)]``, with ``rotor_vel(1)`` + appended if ``model_rotor_vel=True``. Note that ``rotor_vel`` here represents the thrust + state in Newtons. + * ``U``: Input vector ``[roll_rad, pitch_rad, yaw_rad, thrust_N]``. + * ``Y``: Output ``[pos(3), quat(4)]``. + """ + ## We need to set the rpy and drpy symbols before building the euler dynamics + _saved_rpy = symbols.rpy + _saved_drpy = symbols.drpy + _rpy_quat = rotation.cs_quat2euler(symbols.quat) + _drpy_quat = rotation.cs_ang_vel2rpy_rates(symbols.quat, symbols.ang_vel) + symbols.rpy = _rpy_quat + symbols.drpy = _drpy_quat + X_dot_euler, X_euler, U_euler, Y_euler = symbolic_dynamics_euler( + model_rotor_vel=model_rotor_vel, + mass=mass, + gravity_vec=gravity_vec, + J=J, + J_inv=J_inv, + thrust_time_coef=thrust_time_coef, + acc_coef=acc_coef, + cmd_f_coef=cmd_f_coef, + rpy_coef=rpy_coef, + rpy_rates_coef=rpy_rates_coef, + cmd_rpy_coef=cmd_rpy_coef, + ) + symbols.rpy = _saved_rpy + symbols.drpy = _saved_drpy + + # States and Inputs + X = cs.vertcat(symbols.pos, symbols.quat, symbols.vel, symbols.ang_vel) + if model_rotor_vel: + X = cs.vertcat(X, symbols.rotor_vel) + if model_dist_f: + X = cs.vertcat(X, symbols.dist_f) + if model_dist_t: + X = cs.vertcat(X, symbols.dist_t) + U = U_euler + + # Linear equation of motion + pos_dot = X_dot_euler[0:3] + vel_dot = X_dot_euler[6:9] + if model_dist_f: + # Adding force disturbances to the state + vel_dot = vel_dot + symbols.dist_f / mass + + # Rotational equation of motion + xi = cs.vertcat( + cs.horzcat(0, -symbols.ang_vel.T), cs.horzcat(symbols.ang_vel, -cs.skew(symbols.ang_vel)) + ) + quat_dot = 0.5 * (xi @ symbols.quat) + ang_vel_dot = rotation.cs_rpy_rates_deriv2ang_vel_deriv( + symbols.quat, _drpy_quat, X_dot_euler[9:12] + ) + if model_dist_t: + # adding torque disturbances to the state + # angular acceleration can be converted to total torque + torque = J @ ang_vel_dot + cs.cross(symbols.ang_vel, J @ symbols.ang_vel) + # adding torque + torque = torque + symbols.rot.T @ symbols.dist_t + # back to angular acceleration + ang_vel_dot = J_inv @ (torque - cs.cross(symbols.ang_vel, J @ symbols.ang_vel)) + + if model_rotor_vel: + X_dot = cs.vertcat(pos_dot, quat_dot, vel_dot, ang_vel_dot, X_dot_euler[-4:]) + else: + X_dot = cs.vertcat(pos_dot, quat_dot, vel_dot, ang_vel_dot) + Y = cs.vertcat(symbols.pos, symbols.quat) + + return X_dot, X, U, Y + + +def symbolic_dynamics_euler( + model_rotor_vel: bool = True, + *, + mass: float, + gravity_vec: Array, + J: Array, + J_inv: Array, + thrust_time_coef: Array, + acc_coef: Array, + cmd_f_coef: Array, + rpy_coef: Array, + rpy_rates_coef: Array, + cmd_rpy_coef: Array, +) -> tuple[cs.MX, cs.MX, cs.MX, cs.MX]: + """Return CasADi symbolic expressions for the so_rpy_rotor dynamics in Euler-angle form. + + This is the native representation of the ``so_rpy_rotor`` dynamics. The state uses + roll/pitch/yaw and their rates rather than quaternion + angular velocity, which avoids + trigonometric overhead inside CasADi-based solvers. + + Args: + model_rotor_vel: If ``True``, the scalar thrust state is included in ``X`` and first-order + thrust dynamics are modelled. Defaults to ``True``. + mass: Drone mass in kg. + gravity_vec: Gravity vector, shape ``(3,)``. + J: Inertia matrix, shape ``(3, 3)``. + J_inv: Inverse inertia matrix, shape ``(3, 3)``. + thrust_time_coef: First-order thrust lag time constant coefficient (1/s). + acc_coef: Scalar acceleration offset coefficient. + cmd_f_coef: Collective-thrust-to-acceleration coefficient. + rpy_coef: RPY state feedback coefficient, shape ``(3,)``. + rpy_rates_coef: RPY-rate feedback coefficient, shape ``(3,)``. + cmd_rpy_coef: RPY command feedforward coefficient, shape ``(3,)``. + + Returns: + Tuple ``(X_dot, X, U, Y)`` of CasADi ``MX`` expressions: + + * ``X_dot``: State derivative, length 13 when ``model_rotor_vel=True`` (12 otherwise). + * ``X``: State vector ``[pos(3), rpy(3), vel(3), drpy(3)]``, with ``rotor_vel(1)`` appended + if ``model_rotor_vel=True``. Note that ``rotor_vel`` here represents the thrust state in + Newtons. + * ``U``: Input vector ``[roll_rad, pitch_rad, yaw_rad, thrust_N]``. + * ``Y``: Output ``[pos(3), rpy(3)]``. + """ + # States and Inputs + X = cs.vertcat(symbols.pos, symbols.rpy, symbols.vel, symbols.drpy) + if model_rotor_vel: + X = cs.vertcat(X, symbols.rotor_vel) + U = symbols.cmd_rpyt + cmd_rpy = U[:3] + cmd_thrust = U[-1] + rot = rotation.cs_rpy2matrix(symbols.rpy) + + # Defining the dynamics function + # Note that we are abusing the rotor_vel state as the thrust + if model_rotor_vel: + rotor_vel_dot = 1 / thrust_time_coef * (cmd_thrust - symbols.rotor_vel) + forces_motor = symbols.rotor_vel[0] # We are only using the first element + else: + forces_motor = cmd_thrust + + # Creating force vector + forces_motor_vec = cs.vertcat(0, 0, acc_coef + cmd_f_coef * forces_motor) + + # Linear equation of motion + pos_dot = symbols.vel + vel_dot = rot @ forces_motor_vec / mass + gravity_vec + + ddrpy = rpy_coef * symbols.rpy + rpy_rates_coef * symbols.drpy + cmd_rpy_coef * cmd_rpy + + if model_rotor_vel: + X_dot = cs.vertcat(pos_dot, symbols.drpy, vel_dot, ddrpy, rotor_vel_dot) + else: + X_dot = cs.vertcat(pos_dot, symbols.drpy, vel_dot, ddrpy) + Y = cs.vertcat(symbols.pos, symbols.rpy) + + return X_dot, X, U, Y + + +@dataclass +class Params: + mass: Array # (N, M, 1) + """Mass of the drone.""" + gravity_vec: Array # (N, M, 3) + """Gravity vector of the drone.""" + J: Array # (N, M, 3, 3) + """Inertia matrix of the drone.""" + J_inv: Array # (N, M, 3, 3) + """Inverse of the inertia matrix of the drone.""" + thrust_time_coef: Array # (N, M, 1) + """Rotor coefficient of the drone.""" + acc_coef: Array # (N, M, 1) + """Acceleration coefficient of the drone.""" + cmd_f_coef: Array # (N, M, 1) + """Collective thrust coefficient of the drone.""" + rpy_coef: Array # (N, M, 1) + """Roll pitch yaw coefficient of the drone.""" + rpy_rates_coef: Array # (N, M, 1) + """Roll pitch yaw rates coefficient of the drone.""" + cmd_rpy_coef: Array # (N, M, 1) + """Roll pitch yaw command coefficient of the drone.""" + + @staticmethod + def create(n_worlds: int, n_drones: int, drone: str, device: Device) -> Params: + """Create a default set of parameters for the simulation.""" + p = load_params(dynamics, drone) + J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device) + return Params( + mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device), + gravity_vec=jnp.asarray(p["gravity_vec"], device=device), + J=J, + J_inv=jnp.linalg.inv(J), + thrust_time_coef=jnp.asarray(p["thrust_time_coef"], device=device), + acc_coef=jnp.asarray(p["acc_coef"], device=device), + cmd_f_coef=jnp.asarray(p["cmd_f_coef"], device=device), + rpy_coef=jnp.asarray(p["rpy_coef"], device=device), + rpy_rates_coef=jnp.asarray(p["rpy_rates_coef"], device=device), + cmd_rpy_coef=jnp.asarray(p["cmd_rpy_coef"], device=device), + ) + + +def sim_dynamics(data: SimData) -> SimData: + """Compute the forces and torques from the so_rpy_rotor dynamics.""" + params: Params = data.params + vel, _, acc, ang_acc, rotor_acc = dynamics( + pos=data.states.pos, + quat=data.states.quat, + vel=data.states.vel, + ang_vel=data.states.ang_vel, + rotor_vel=data.states.rotor_vel, + cmd=data.controls.attitude.cmd, + dist_f=data.states.force, + dist_t=data.states.torque, + **params.__dict__, + ) + states_deriv = data.states_deriv.replace( + vel=vel, ang_vel=data.states.ang_vel, acc=acc, ang_acc=ang_acc, rotor_acc=rotor_acc + ) + return data.replace(states_deriv=states_deriv) diff --git a/crazyflow/dynamics/so_rpy_rotor/params.toml b/crazyflow/dynamics/so_rpy_rotor/params.toml new file mode 100644 index 0000000..edb6a40 --- /dev/null +++ b/crazyflow/dynamics/so_rpy_rotor/params.toml @@ -0,0 +1,34 @@ +[cf2x_L250] +acc_coef = 0.0 +cmd_f_coef = 0.97732585 +thrust_time_coef = 0.0858607 +rpy_coef = [-245.67, -245.67, -227.78] +rpy_rates_coef = [-17.32, -17.32, -25.63] +cmd_rpy_coef = [196.18, 196.18, 390.27] + + +[cf2x_P250] +acc_coef = 0.0 +cmd_f_coef = 0.98323006 +thrust_time_coef = 0.0578952 +rpy_coef = [-319.14, -319.14, -284.28] +rpy_rates_coef = [-20.85, -20.85, -38.43] +cmd_rpy_coef = [263.30, 263.30, 502.58] + + +[cf2x_T350] +acc_coef = 0.0 +cmd_f_coef = 1.022561164673754 +thrust_time_coef = 0.5712805549388994 # High value, maybe not correct? +rpy_coef = [-371.41695523, -371.41695523, -261.99549945] +rpy_rates_coef = [-29.26311118, -29.26311118, -29.74357219] +cmd_rpy_coef = [347.94260321, 347.94260321, 241.06977014] + + +[cf21B_500] +acc_coef = 0.0 +cmd_f_coef = 0.96841816 +thrust_time_coef = 0.02055366 +rpy_coef = [-188.9910, -188.9910, -138.3109] +rpy_rates_coef = [-12.7803, -12.7803, -16.8485] +cmd_rpy_coef = [138.0834, 138.0834, 198.5161] \ No newline at end of file diff --git a/crazyflow/dynamics/so_rpy_rotor_drag/__init__.py b/crazyflow/dynamics/so_rpy_rotor_drag/__init__.py new file mode 100644 index 0000000..1495895 --- /dev/null +++ b/crazyflow/dynamics/so_rpy_rotor_drag/__init__.py @@ -0,0 +1,39 @@ +r"""Second-order fitted RPY dynamics with thrust dynamics and linear drag. + +Extends ``so_rpy_rotor`` by adding a body-frame linear drag term to the translational dynamics. +Rotational dynamics remain a fitted second-order linear system, and thrust spin-up uses a +first-order lag. The command interface is ``[roll_rad, pitch_rad, yaw_rad, thrust_N]``. The +``rotor_vel`` state is the current thrust in Newtons (not motor RPMs). + +\[ +\begin{aligned} + \dot{F} &= \frac{1}{\tau}(F_{\mathrm{cmd}} - F), \\ + \dot{\mathbf{p}} &= \mathbf{v}, \\ + m\dot{\mathbf{v}} &= m\mathbf{g} + + (c_{\mathrm{acc}} + c_f F)\,R\,\mathbf{e}_z + + R\,D_b\,R^{\top}\mathbf{v}, \\ + \ddot{\boldsymbol{\psi}} &= + c_{\psi}\,\boldsymbol{\psi} + + c_{\dot{\psi}}\,\dot{\boldsymbol{\psi}} + + c_u\,\mathbf{u}_{\mathrm{rpy}}, +\end{aligned} +\] + +where \(\tau\) is the thrust time constant, \(\boldsymbol{\psi} = [\phi,\theta,\psi]^{\top}\) are +the roll/pitch/yaw angles with rates \(\dot{\boldsymbol{\psi}}\), +\(R = {}^{\mathcal{I}}R_{\mathcal{B}}(\boldsymbol{\psi})\) is the rotation from body to world frame, +and \(D_b\) is the diagonal body-frame aerodynamic drag matrix. + +This is the native Euler-angle form. For how the simulation integrates this state in quaternion + +angular velocity coordinates, see [so_rpy][crazyflow.dynamics.so_rpy]. +""" + +from crazyflow.dynamics.so_rpy_rotor_drag.dynamics import ( + Params, + dynamics, + sim_dynamics, + symbolic_dynamics, + symbolic_dynamics_euler, +) + +__all__ = ["Params", "dynamics", "sim_dynamics", "symbolic_dynamics", "symbolic_dynamics_euler"] diff --git a/crazyflow/dynamics/so_rpy_rotor_drag/dynamics.py b/crazyflow/dynamics/so_rpy_rotor_drag/dynamics.py new file mode 100644 index 0000000..6155a72 --- /dev/null +++ b/crazyflow/dynamics/so_rpy_rotor_drag/dynamics.py @@ -0,0 +1,470 @@ +"""Second-order fitted RPY dynamics with thrust dynamics and linear drag. + +This module extends the ``so_rpy_rotor`` by adding a body-frame linear drag term to the +translational dynamics. Rotational dynamics are still modelled as a fitted second-order linear +system, and thrust spin-up uses a first-order lag. + +The command interface is ``[roll_rad, pitch_rad, yaw_rad, thrust_N]``. The ``rotor_vel`` state is a +**scalar thrust state in Newtons** (not motor RPMs). + +Both a numeric implementation ([dynamics][crazyflow.dynamics.so_rpy_rotor_drag.dynamics]) and +symbolic CasADi implementations +([symbolic_dynamics][crazyflow.dynamics.so_rpy_rotor_drag.symbolic_dynamics], +[symbolic_dynamics_euler][crazyflow.dynamics.so_rpy_rotor_drag.symbolic_dynamics_euler]) are +provided. +""" + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +import casadi as cs +import jax +import jax.numpy as jnp +from array_api_compat import array_namespace +from array_api_compat import device as xp_device +from flax.struct import dataclass +from scipy.spatial.transform import Rotation as R + +import crazyflow.dynamics.symbols as symbols +from crazyflow.dynamics.core import load_params, supports +from crazyflow.dynamics.utils import rotation +from crazyflow.utils import to_xp + +if TYPE_CHECKING: + from jax import Device + + from crazyflow._typing import Array # To be changed to array_api_typing later + from crazyflow.sim.data import SimData + + +# Additional symbols specific to these dynamics +roll, pitch, yaw = cs.MX.sym("roll"), cs.MX.sym("pitch"), cs.MX.sym("yaw") +rpy = cs.vertcat(roll, pitch, yaw) # Euler angles +droll, dpitch, dyaw = cs.MX.sym("droll"), cs.MX.sym("dpitch"), cs.MX.sym("dyaw") +drpy = cs.vertcat(droll, dpitch, dyaw) # Euler angle rates +ddroll, ddpitch, ddyaw = cs.MX.sym("ddroll"), cs.MX.sym("ddpitch"), cs.MX.sym("ddyaw") +rpy_ddot = cs.vertcat(ddroll, ddpitch, ddyaw) # Euler angle rates derivatives + + +@supports(rotor_dynamics=True) +def dynamics( + pos: Array, + quat: Array, + vel: Array, + ang_vel: Array, + cmd: Array, + rotor_vel: Array | None = None, + dist_f: Array | None = None, + dist_t: Array | None = None, + *, + mass: float, + gravity_vec: Array, + J: Array, + J_inv: Array, + thrust_time_coef: Array, + acc_coef: Array, + cmd_f_coef: Array, + rpy_coef: Array, + rpy_rates_coef: Array, + cmd_rpy_coef: Array, + drag_matrix: Array, +) -> tuple[Array, Array, Array, Array, Array | None]: + """Fitted linear, second order rpy dynamics with thrust dynamics and drag. + + Converts the state to Euler angles, evaluates ``dynamics_euler``, maps the derivatives back, and + adds the force/torque disturbances. + + Args: + pos: Position of the drone (m). + quat: Quaternion of the drone (xyzw). + vel: Velocity of the drone (m/s). + ang_vel: Angular velocity of the drone (rad/s). + cmd: Roll pitch yaw (rad) and collective thrust (N) command. + rotor_vel: Speed of the 4 motors (RPMs). If None, the commanded thrust is directly + applied (not recommended). If value is given, rotor dynamics are calculated. + dist_f: Disturbance force (N) in the world frame acting on the CoM. + dist_t: Disturbance torque (Nm) in the world frame acting on the CoM. + + mass: Mass of the drone (kg). + gravity_vec: Gravity vector (m/s^2). We assume the gravity vector points downwards, e.g. + [0, 0, -9.81]. + J: Inertia matrix (kg m^2). + J_inv: Inverse inertia matrix (1/kg m^2). + thrust_time_coef: Coefficient for the rotor dynamics (1/s). + acc_coef: Coefficient for the acceleration (1/s^2). + cmd_f_coef: Coefficient for the collective thrust (N/rad^2). + rpy_coef: Coefficient for the roll pitch yaw dynamics (1/s). + rpy_rates_coef: Coefficient for the roll pitch yaw rates dynamics (1/s^2). + cmd_rpy_coef: Coefficient for the roll pitch yaw command dynamics (1/s). + drag_matrix: Coefficient matrix for the linear drag (1/s). + + Returns: + The derivatives (pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot). + """ + xp = array_namespace(pos) + # Convert parameters to correct xp framework + device = xp_device(pos) + mass, J, J_inv = to_xp(mass, J, J_inv, xp=xp, device=device) + + # Convert to the native Euler-angle state, evaluate the core dynamics, then map back + rot = R.from_quat(quat) + rpy = rot.as_euler("xyz") + rpy_rates = rotation.ang_vel2rpy_rates(quat, ang_vel) + pos_dot, _, vel_dot, rpy_rates_dot, rotor_vel_dot = dynamics_euler( + pos, + rpy, + vel, + rpy_rates, + cmd, + rotor_vel, + mass=mass, + gravity_vec=gravity_vec, + thrust_time_coef=thrust_time_coef, + acc_coef=acc_coef, + cmd_f_coef=cmd_f_coef, + rpy_coef=rpy_coef, + rpy_rates_coef=rpy_rates_coef, + cmd_rpy_coef=cmd_rpy_coef, + drag_matrix=drag_matrix, + ) + + if dist_f is not None: + vel_dot = vel_dot + dist_f / mass # Adding force disturbances to the state + quat_dot = rotation.ang_vel2quat_dot(quat, ang_vel) + ang_vel_dot = rotation.rpy_rates_deriv2ang_vel_deriv(quat, rpy_rates, rpy_rates_dot) + if dist_t is not None: + # adding torque disturbances to the state + # angular acceleration can be converted to total torque given the inertia matrix + torque = (J @ ang_vel_dot[..., None])[..., 0] + torque = torque + xp.linalg.cross(ang_vel, (J @ ang_vel[..., None])[..., 0]) + # adding torque. TODO: This should be a linear transformation. Can't we just transform the + # disturbance torque to an ang_vel_dot summand directly? + torque = torque + rot.apply(dist_t, inverse=True) + # back to angular acceleration + torque = torque - xp.linalg.cross(ang_vel, (J @ ang_vel[..., None])[..., 0]) + ang_vel_dot = (J_inv @ torque[..., None])[..., 0] + + return pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot + + +def dynamics_euler( + pos: Array, + rpy: Array, + vel: Array, + rpy_rates: Array, + cmd: Array, + rotor_vel: Array | None = None, + *, + mass: float, + gravity_vec: Array, + thrust_time_coef: Array, + acc_coef: Array, + cmd_f_coef: Array, + rpy_coef: Array, + rpy_rates_coef: Array, + cmd_rpy_coef: Array, + drag_matrix: Array, +) -> tuple[Array, Array, Array, Array, Array | None]: + """Core second-order rpy, thrust, and drag dynamics in Euler-angle coordinates.""" + xp = array_namespace(pos) + device = xp_device(pos) + mass, gravity_vec = to_xp(mass, gravity_vec, xp=xp, device=device) + thrust_time_coef, acc_coef = to_xp(thrust_time_coef, acc_coef, xp=xp, device=device) + cmd_f_coef, rpy_coef = to_xp(cmd_f_coef, rpy_coef, xp=xp, device=device) + rpy_rates_coef, cmd_rpy_coef = to_xp(rpy_rates_coef, cmd_rpy_coef, xp=xp, device=device) + drag_matrix = to_xp(drag_matrix, xp=xp, device=device) + cmd_f = cmd[..., -1] + cmd_rpy = cmd[..., 0:3] + # Note that we are abusing the rotor_vel state as the thrust + if rotor_vel is None: + warnings.warn("Rotor velocity not provided, using commanded rotor velocity.") + rotor_vel, rotor_vel_dot = cmd_f[..., None], None + else: + rotor_vel_dot = 1 / thrust_time_coef * (cmd_f[..., None] - rotor_vel) + forces_motor = rotor_vel[..., 0] + thrust = acc_coef + cmd_f_coef * forces_motor + rot_mat = R.from_euler("xyz", rpy).inv().as_matrix() # rotation from world to body + drone_z_axis = rot_mat[..., -1, :] + pos_dot = vel + vel_dot = ( + 1 / mass * thrust[..., None] * drone_z_axis + + gravity_vec + + 1 / mass * (rot_mat.mT @ (drag_matrix @ (rot_mat @ vel[..., None])))[..., 0] + ) + rpy_rates_dot = rpy_coef * rpy + rpy_rates_coef * rpy_rates + cmd_rpy_coef * cmd_rpy + return pos_dot, rpy_rates, vel_dot, rpy_rates_dot, rotor_vel_dot + + +def symbolic_dynamics( + model_rotor_vel: bool = True, + model_dist_f: bool = False, + model_dist_t: bool = False, + *, + mass: float, + gravity_vec: Array, + J: Array, + J_inv: Array, + thrust_time_coef: Array, + acc_coef: Array, + cmd_f_coef: Array, + rpy_coef: Array, + rpy_rates_coef: Array, + cmd_rpy_coef: Array, + drag_matrix: Array, +) -> tuple[cs.MX, cs.MX, cs.MX, cs.MX]: + """Return CasADi symbolic expressions for the so_rpy_rotor_drag dynamics in quaternion form. + + Internally delegates to + [symbolic_dynamics_euler][crazyflow.dynamics.so_rpy_rotor_drag.symbolic_dynamics_euler] and + converts the Euler-angle state to quaternion + angular-velocity state so that the interface + matches that of [symbolic_dynamics][crazyflow.dynamics.first_principles.symbolic_dynamics]. + + Args: + model_rotor_vel: If ``True``, the scalar thrust state is included in ``X`` and first-order + thrust dynamics are modelled. Defaults to ``True``. + model_dist_f: If ``True``, a 3-D force disturbance is appended to ``X``. + model_dist_t: If ``True``, a 3-D torque disturbance is appended to ``X``. + mass: Drone mass in kg. + gravity_vec: Gravity vector, shape ``(3,)``. + J: Inertia matrix, shape ``(3, 3)``. + J_inv: Inverse inertia matrix, shape ``(3, 3)``. + thrust_time_coef: First-order thrust lag time constant coefficient (1/s). + acc_coef: Scalar acceleration offset coefficient. + cmd_f_coef: Collective-thrust-to-acceleration coefficient. + rpy_coef: RPY state feedback coefficient, shape ``(3,)``. + rpy_rates_coef: RPY-rate feedback coefficient, shape ``(3,)``. + cmd_rpy_coef: RPY command feedforward coefficient, shape ``(3,)``. + drag_matrix: Diagonal ``(3, 3)`` matrix of linear drag coefficients applied in the body + frame. + + Returns: + Tuple ``(X_dot, X, U, Y)`` of CasADi ``MX`` expressions: + + * ``X_dot``: State derivative, length 14 when ``model_rotor_vel=True`` (13 otherwise), plus + 3 per enabled disturbance. + * ``X``: State vector ``[pos(3), quat(4), vel(3), ang_vel(3)]``, with ``rotor_vel(1)`` + appended if ``model_rotor_vel=True``. Note that ``rotor_vel`` here represents the thrust + state in Newtons. + * ``U``: Input vector ``[roll_rad, pitch_rad, yaw_rad, thrust_N]``. + * ``Y``: Output ``[pos(3), quat(4)]``. + """ + # Temporarily override rpy/drpy so symbolic_dynamics_euler uses quaternion-derived + # expressions for this call. Restore them afterwards so subsequent calls to + # symbolic_dynamics_euler still get the original leaf symbolic variables. + _saved_rpy = symbols.rpy + _saved_drpy = symbols.drpy + _rpy_quat = rotation.cs_quat2euler(symbols.quat) + _drpy_quat = rotation.cs_ang_vel2rpy_rates(symbols.quat, symbols.ang_vel) + symbols.rpy = _rpy_quat + symbols.drpy = _drpy_quat + X_dot_euler, X_euler, U_euler, Y_euler = symbolic_dynamics_euler( + model_rotor_vel=model_rotor_vel, + mass=mass, + gravity_vec=gravity_vec, + J=J, + J_inv=J_inv, + thrust_time_coef=thrust_time_coef, + acc_coef=acc_coef, + cmd_f_coef=cmd_f_coef, + rpy_coef=rpy_coef, + rpy_rates_coef=rpy_rates_coef, + cmd_rpy_coef=cmd_rpy_coef, + drag_matrix=drag_matrix, + ) + symbols.rpy = _saved_rpy + symbols.drpy = _saved_drpy + + # States and Inputs + X = cs.vertcat(symbols.pos, symbols.quat, symbols.vel, symbols.ang_vel) + if model_rotor_vel: + X = cs.vertcat(X, symbols.rotor_vel) + if model_dist_f: + X = cs.vertcat(X, symbols.dist_f) + if model_dist_t: + X = cs.vertcat(X, symbols.dist_t) + U = U_euler + + # Linear equation of motion + pos_dot = X_dot_euler[0:3] + vel_dot = X_dot_euler[6:9] + if model_dist_f: + # Adding force disturbances to the state + vel_dot = vel_dot + symbols.dist_f / mass + + # Rotational equation of motion + xi = cs.vertcat( + cs.horzcat(0, -symbols.ang_vel.T), cs.horzcat(symbols.ang_vel, -cs.skew(symbols.ang_vel)) + ) + quat_dot = 0.5 * (xi @ symbols.quat) + ang_vel_dot = rotation.cs_rpy_rates_deriv2ang_vel_deriv( + symbols.quat, _drpy_quat, X_dot_euler[9:12] + ) + if model_dist_t: + # adding torque disturbances to the state + # angular acceleration can be converted to total torque + torque = J @ ang_vel_dot + cs.cross(symbols.ang_vel, J @ symbols.ang_vel) + # adding torque + torque = torque + symbols.rot.T @ symbols.dist_t + # back to angular acceleration + ang_vel_dot = J_inv @ (torque - cs.cross(symbols.ang_vel, J @ symbols.ang_vel)) + + if model_rotor_vel: + X_dot = cs.vertcat(pos_dot, quat_dot, vel_dot, ang_vel_dot, X_dot_euler[-4:]) + else: + X_dot = cs.vertcat(pos_dot, quat_dot, vel_dot, ang_vel_dot) + Y = cs.vertcat(symbols.pos, symbols.quat) + + return X_dot, X, U, Y + + +def symbolic_dynamics_euler( + model_rotor_vel: bool = True, + *, + mass: float, + gravity_vec: Array, + J: Array, + J_inv: Array, + thrust_time_coef: Array, + acc_coef: Array, + cmd_f_coef: Array, + rpy_coef: Array, + rpy_rates_coef: Array, + cmd_rpy_coef: Array, + drag_matrix: Array, +) -> tuple[cs.MX, cs.MX, cs.MX, cs.MX]: + """Return CasADi symbolic expressions for the so_rpy_rotor_drag dynamics in Euler-angle form. + + This is the native representation of the ``so_rpy_rotor_drag`` dynamics. The state uses + roll/pitch/yaw and their rates rather than quaternion + angular velocity, which avoids + trigonometric overhead inside CasADi-based solvers. + + Args: + model_rotor_vel: If ``True``, the scalar thrust state is included in ``X`` and first-order + thrust dynamics are modelled. Defaults to ``True``. + mass: Drone mass in kg. + gravity_vec: Gravity vector, shape ``(3,)``. + J: Inertia matrix, shape ``(3, 3)``. + J_inv: Inverse inertia matrix, shape ``(3, 3)``. + thrust_time_coef: First-order thrust lag time constant coefficient (1/s). + acc_coef: Scalar acceleration offset coefficient. + cmd_f_coef: Collective-thrust-to-acceleration coefficient. + rpy_coef: RPY state feedback coefficient, shape ``(3,)``. + rpy_rates_coef: RPY-rate feedback coefficient, shape ``(3,)``. + cmd_rpy_coef: RPY command feedforward coefficient, shape ``(3,)``. + drag_matrix: Diagonal ``(3, 3)`` matrix of linear drag coefficients applied in the body + frame. + + Returns: + Tuple ``(X_dot, X, U, Y)`` of CasADi ``MX`` expressions: + + * ``X_dot``: State derivative, length 13 when ``model_rotor_vel=True`` (12 otherwise). + * ``X``: State vector ``[pos(3), rpy(3), vel(3), drpy(3)]``, with ``rotor_vel(1)`` appended + if ``model_rotor_vel=True``. Note that ``rotor_vel`` here represents the thrust state in + Newtons. + * ``U``: Input vector ``[roll_rad, pitch_rad, yaw_rad, thrust_N]``. + * ``Y``: Output ``[pos(3), rpy(3)]``. + """ + # States and Inputs + X = cs.vertcat(symbols.pos, symbols.rpy, symbols.vel, symbols.drpy) + if model_rotor_vel: + X = cs.vertcat(X, symbols.rotor_vel) + U = symbols.cmd_rpyt + cmd_rpy = U[:3] + cmd_thrust = U[-1] + rot = rotation.cs_rpy2matrix(symbols.rpy) # rotation matrix from body to world + + # Defining the dynamics function + # Note that we are abusing the rotor_vel state as the thrust + if model_rotor_vel: + rotor_vel_dot = 1 / thrust_time_coef * (cmd_thrust - symbols.rotor_vel) + forces_motor = symbols.rotor_vel[0] # We are only using the first element + else: + forces_motor = cmd_thrust + + # Creating force vector + forces_motor_vec = cs.vertcat(0, 0, acc_coef + cmd_f_coef * forces_motor) + + # Linear equation of motion + pos_dot = symbols.vel + vel_dot = ( + 1 / mass * rot @ forces_motor_vec + + gravity_vec + + 1 / mass * rot @ drag_matrix @ rot.T @ symbols.vel + ) + + ddrpy = rpy_coef * symbols.rpy + rpy_rates_coef * symbols.drpy + cmd_rpy_coef * cmd_rpy + + if model_rotor_vel: + X_dot = cs.vertcat(pos_dot, symbols.drpy, vel_dot, ddrpy, rotor_vel_dot) + else: + X_dot = cs.vertcat(pos_dot, symbols.drpy, vel_dot, ddrpy) + Y = cs.vertcat(symbols.pos, symbols.rpy) + + return X_dot, X, U, Y + + +@dataclass +class Params: + mass: Array # (N, M, 1) + """Mass of the drone.""" + gravity_vec: Array # (N, M, 3) + """Gravity vector of the drone.""" + J: Array # (N, M, 3, 3) + """Inertia matrix of the drone.""" + J_inv: Array # (N, M, 3, 3) + """Inverse of the inertia matrix of the drone.""" + thrust_time_coef: Array # (N, M, 1) + """Rotor coefficient of the drone.""" + acc_coef: Array # (N, M, 1) + """Acceleration coefficient of the drone.""" + cmd_f_coef: Array # (N, M, 1) + """Collective thrust coefficient of the drone.""" + rpy_coef: Array # (N, M, 1) + """Roll pitch yaw coefficient of the drone.""" + rpy_rates_coef: Array # (N, M, 1) + """Roll pitch yaw rates coefficient of the drone.""" + cmd_rpy_coef: Array # (N, M, 1) + """Roll pitch yaw command coefficient of the drone.""" + drag_matrix: Array # (N, M, 3, 3) + """Linear drag coefficient matrix of the drone.""" + + @staticmethod + def create(n_worlds: int, n_drones: int, drone: str, device: Device) -> Params: + """Create a default set of parameters for the simulation.""" + p = load_params(dynamics, drone) + J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device) + return Params( + mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device), + gravity_vec=jnp.asarray(p["gravity_vec"], device=device), + J=J, + J_inv=jnp.linalg.inv(J), + thrust_time_coef=jnp.asarray(p["thrust_time_coef"], device=device), + acc_coef=jnp.asarray(p["acc_coef"], device=device), + cmd_f_coef=jnp.asarray(p["cmd_f_coef"], device=device), + rpy_coef=jnp.asarray(p["rpy_coef"], device=device), + rpy_rates_coef=jnp.asarray(p["rpy_rates_coef"], device=device), + cmd_rpy_coef=jnp.asarray(p["cmd_rpy_coef"], device=device), + drag_matrix=jnp.asarray(p["drag_matrix"], device=device), + ) + + +def sim_dynamics(data: SimData) -> SimData: + """Compute the forces and torques from the so_rpy_rotor_drag dynamics.""" + params: Params = data.params + vel, _, acc, ang_acc, rotor_acc = dynamics( + pos=data.states.pos, + quat=data.states.quat, + vel=data.states.vel, + ang_vel=data.states.ang_vel, + cmd=data.controls.attitude.cmd, + rotor_vel=data.states.rotor_vel, + dist_f=data.states.force, + dist_t=data.states.torque, + **params.__dict__, + ) + states_deriv = data.states_deriv.replace( + vel=vel, ang_vel=data.states.ang_vel, acc=acc, ang_acc=ang_acc, rotor_acc=rotor_acc + ) + return data.replace(states_deriv=states_deriv) diff --git a/crazyflow/dynamics/so_rpy_rotor_drag/params.toml b/crazyflow/dynamics/so_rpy_rotor_drag/params.toml new file mode 100644 index 0000000..3c37ef1 --- /dev/null +++ b/crazyflow/dynamics/so_rpy_rotor_drag/params.toml @@ -0,0 +1,54 @@ +[cf2x_L250] +acc_coef = 0.0 +cmd_f_coef = 0.98325003 +thrust_time_coef = 0.12116392 +drag_matrix = [ + [-0.01471782, 0.0, 0.0 ], + [0.0, -0.01471782, 0.0 ], + [0.0, 0.0, -0.01277641 ] +] +rpy_coef = [-245.67, -245.67, -227.78] +rpy_rates_coef = [-17.32, -17.32, -25.63] +cmd_rpy_coef = [196.18, 196.18, 390.27] + + +[cf2x_P250] +acc_coef = 0.0 +cmd_f_coef = 0.99085062 +thrust_time_coef = 0.11554009 +drag_matrix = [ + [-0.01351483, 0.0, 0.0 ], + [0.0, -0.01351483, 0.0 ], + [0.0, 0.0, -0.01677452 ] +] +rpy_coef = [-319.14, -319.14, -284.28] +rpy_rates_coef = [-20.85, -20.85, -38.43] +cmd_rpy_coef = [263.30, 263.30, 502.58] + + +[cf2x_T350] +acc_coef = 0.0 +cmd_f_coef = 1.0226418398769022 +thrust_time_coef = 0.16711124468068936 +drag_matrix = [ + [-0.01521728, 0.0, 0.0 ], + [0.0, -0.01521728, 0.0 ], + [0.0, 0.0, -0.02144565 ] +] +rpy_coef = [-371.41695523, -371.41695523, -261.99549945] +rpy_rates_coef = [-29.26311118, -29.26311118, -29.74357219] +cmd_rpy_coef = [347.94260321, 347.94260321, 241.06977014] + + +[cf21B_500] +acc_coef = 0.0 +cmd_f_coef = 0.98023254 +thrust_time_coef = 0.07993871 +drag_matrix = [ + [-0.02149163, 0.0, 0.0 ], + [0.0, -0.02149163, 0.0 ], + [0.0, 0.0, -0.02359736 ] +] +rpy_coef = [-188.9910, -188.9910, -138.3109] +rpy_rates_coef = [-12.7803, -12.7803, -16.8485] +cmd_rpy_coef = [138.0834, 138.0834, 198.5161] diff --git a/crazyflow/dynamics/symbols.py b/crazyflow/dynamics/symbols.py new file mode 100644 index 0000000..ebf356f --- /dev/null +++ b/crazyflow/dynamics/symbols.py @@ -0,0 +1,53 @@ +"""Symbols used in the symbolic drone dynamics. + +Can be used to define symbolic CasADi expressions that are passed to model-based optimizers such as +Acados. +""" + +import casadi as cs + +from crazyflow.dynamics.utils import rotation + +# States +px, py, pz = cs.MX.sym("px"), cs.MX.sym("py"), cs.MX.sym("pz") +pos = cs.vertcat(px, py, pz) +"""Symbolic drone position.""" +qw, qx, qy, qz = cs.MX.sym("qw"), cs.MX.sym("qx"), cs.MX.sym("qy"), cs.MX.sym("qz") +quat = cs.vertcat(qx, qy, qz, qw) +"""Symbolic drone orientation as xyzw quaternion.""" +rot = rotation.cs_quat2matrix(quat) # Rotation matrix from body to world frame +vx, vy, vz = cs.MX.sym("vx"), cs.MX.sym("vy"), cs.MX.sym("vz") +vel = cs.vertcat(vx, vy, vz) +"""Symbolic drone velocity.""" +wx, wy, wz = cs.MX.sym("wx"), cs.MX.sym("wy"), cs.MX.sym("wz") +ang_vel = cs.vertcat(wx, wy, wz) +"""Symbolic drone angular velocity.""" +w1, w2, w3, w4 = cs.MX.sym("w1"), cs.MX.sym("w2"), cs.MX.sym("w3"), cs.MX.sym("w4") +rotor_vel = cs.vertcat(w1, w2, w3, w4) +"""Symbolic rotor velocities.""" +dfx, dfy, dfz = cs.MX.sym("dfx"), cs.MX.sym("dfy"), cs.MX.sym("dfz") +dist_f = cs.vertcat(dfx, dfy, dfz) +"""Symbolic disturbance forces.""" +dtx, dty, dtz = cs.MX.sym("dtx"), cs.MX.sym("dty"), cs.MX.sym("dtz") +dist_t = cs.vertcat(dtx, dty, dtz) +"""Symbolic disturbance torques.""" + +# Inputs +cmd_w1, cmd_w2, cmd_w3, cmd_w4 = ( + cs.MX.sym("cmd_w1"), + cs.MX.sym("cmd_w2"), + cs.MX.sym("cmd_w3"), + cs.MX.sym("cmd_w4"), +) +cmd_rotor_vel = cs.vertcat(cmd_w1, cmd_w2, cmd_w3, cmd_w4) +"""Symbolic rotor velocity commands.""" +cmd_roll, cmd_pitch, cmd_yaw = (cs.MX.sym("cmd_roll"), cs.MX.sym("cmd_pitch"), cs.MX.sym("cmd_yaw")) +cmd_thrust = cs.MX.sym("cmd_thrust") +cmd_rpyt = cs.vertcat(cmd_roll, cmd_pitch, cmd_yaw, cmd_thrust) +"""Symbolic roll/pitch/yaw/thrust commands.""" + +# Special states for the so_rpy dynamics +roll, pitch, yaw = cs.MX.sym("roll"), cs.MX.sym("pitch"), cs.MX.sym("yaw") +rpy = cs.vertcat(roll, pitch, yaw) +droll, dpitch, dyaw = cs.MX.sym("droll"), cs.MX.sym("dpitch"), cs.MX.sym("dyaw") +drpy = cs.vertcat(droll, dpitch, dyaw) diff --git a/crazyflow/dynamics/utils/__init__.py b/crazyflow/dynamics/utils/__init__.py new file mode 100644 index 0000000..223fc17 --- /dev/null +++ b/crazyflow/dynamics/utils/__init__.py @@ -0,0 +1 @@ +"""Utility functions for the drone dynamics.""" diff --git a/crazyflow/dynamics/utils/data_utils.py b/crazyflow/dynamics/utils/data_utils.py new file mode 100644 index 0000000..7cc86cc --- /dev/null +++ b/crazyflow/dynamics/utils/data_utils.py @@ -0,0 +1,223 @@ +"""Data preprocessing functions for dynamics identification. + +Contains helpers to filter states and compute derivatives using State Variable Filters. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import numpy as np +from scipy.integrate import solve_ivp +from scipy.interpolate import interp1d +from scipy.signal import bilinear, butter, filtfilt, lfilter, lfiltic +from scipy.spatial.transform import Rotation as R + +from crazyflow.dynamics.utils.rotation import rpy_rates2ang_vel + +if TYPE_CHECKING: + from crazyflow._typing import Array # To be changed to array_api_typing later + +logger = logging.getLogger(__name__) + + +def preprocessing(data: dict[str, Array]) -> dict[str, Array]: + """Applies preprocessing to collected data. + + The preprocessing includes outlier detection and interpolation, normalizing orientation + (assuming hover at start), calculating rpy from quaternions, and calculating rotational error. + + Args: + data: The raw data dictionary containing time [s], pos [m], quat, cmd_rpy [rad], cmd_f [N]. + + Returns: + The same dict with the following keys added or modified: + + * ``"dt"``: Time step array, shape ``(N-1,)``. + * ``"time"``: Time shifted so ``time[0] == 0``. + * ``"quat"``: Quaternions corrected so the initial attitude is identity. + * ``"rpy"``: Roll/pitch/yaw in radians, shape ``(N, 3)``. + * ``"z_axis"``: Body z-axis in world frame, shape ``(N, 3)``. + * ``"eR"``: Rotation error vector (vee of skew-symmetric error matrix), shape ``(N, 3)``. + * ``"eR_vec"``: Rotation error as rotation vector, shape ``(N, 3)``. + """ + data["dt"] = np.diff(data["time"]) + data["time"] -= data["time"][0] + ### Outlier detection + interpolation + b, a = butter(N=4, Wn=1, fs=1 / np.mean(data["dt"])) + residuals = data["pos"] - filtfilt(b, a, data["pos"], axis=0) + outliers = np.abs(residuals) > 0.3 + outliers = np.sum(outliers, axis=-1) + is_outlier = np.asarray(outliers).astype(bool) + n_outliers = np.sum(outliers) + # TODO also check quat for outliers! + + if n_outliers > 0: + logger.warning(f"{n_outliers} outliers detected. Interpolating") + time_good = data["time"][~is_outlier] + pos_good = data["pos"][~is_outlier] + quat_good = data["quat"][~is_outlier] + interp_pos = interp1d(time_good, pos_good, axis=0, fill_value="extrapolate") + interp_quat = interp1d(time_good, quat_good, axis=0, fill_value="extrapolate") + data["pos"][is_outlier] = interp_pos(data["time"][is_outlier]) + data["quat"][is_outlier] = interp_quat(data["time"][is_outlier]) + + ### Normalizing orientation (assuming zero at start) and calculating rpy + time_span = 0.1 + time_index = int(time_span / np.mean(data["dt"])) + quat_avg = np.mean(data["quat"][:time_index], axis=0) + quat_avg /= np.linalg.norm(quat_avg) + rot_corr = R.from_quat(quat_avg).inv() + rot = rot_corr * R.from_quat(data["quat"]) + data["quat"] = rot.as_quat() + data["rpy"] = rot.as_euler("xyz") + data["z_axis"] = rot.inv().as_matrix()[..., -1, :] + + ### Rotational error + rot = R.from_quat(data["quat"]) + R_act = rot.as_matrix() + R_des = R.from_euler("xyz", data["cmd_rpy"], degrees=False).as_matrix() + eRM = np.matmul(np.swapaxes(R_des, -1, -2), R_act) - np.matmul( + np.swapaxes(R_act, -1, -2), R_des + ) + data["eR"] = np.stack( + (eRM[..., 2, 1], eRM[..., 0, 2], eRM[..., 1, 0]), axis=-1 + ) # vee operator (SO3 to R3) + data["eR_vec"] = (rot.inv() * R.from_euler("xyz", data["cmd_rpy"], degrees=False)).as_rotvec() + + return data + + +def derivatives_svf(data: dict[str, Array]) -> dict[str, Array]: + """Apply a State Variable Filter (SVF) to compute smoothed signals and their time derivatives. + + Filters position, attitude (RPY), and command signals with separate corner frequencies (6 Hz for + translation, 8 Hz for rotation) and computes up to third-order time derivatives. All output + keys are prefixed with ``"SVF_"``. + + Args: + data: Dict produced by [preprocessing][crazyflow.dynamics.utils.data_utils.preprocessing]. + Must contain ``"pos"``, ``"rpy"``, ``"time"``, ``"cmd_f"``, and ``"cmd_rpy"``. + + Returns: + The same dict with the following ``"SVF_"`` keys added: + + * ``"SVF_pos"``, ``"SVF_vel"``, ``"SVF_acc"``, ``"SVF_jerk"``: Filtered position and its + first three derivatives. + * ``"SVF_rpy"``, ``"SVF_drpy"``, ``"SVF_ddrpy"``, ``"SVF_dddrpy"``: Filtered roll/pitch/yaw + and its first three derivatives. + * ``"SVF_quat"``: Quaternion computed from ``SVF_rpy``. + * ``"SVF_z_axis"``: Body z-axis in world frame computed from ``SVF_rpy``. + * ``"SVF_ang_vel"``, ``"SVF_ang_acc"``, ``"SVF_ang_jerk"``: Angular + velocity/acceleration/jerk in body frame. + * ``"SVF_cmd_f"``: Filtered collective thrust command. + * ``"SVF_cmd_rpy"``: Filtered roll/pitch/yaw command. + * ``"SVF_eR"``, ``"SVF_eR_vec"``: Rotation error between actual and commanded attitude. + """ + # Important: Don't mix with unfiltered signals (also for input!) + if data is None: + return None + + svf_linear = state_variable_filter(data["pos"].T, data["time"], f_c=6, N_deriv=3) + data["SVF_pos"] = svf_linear[:, 0].T + data["SVF_vel"] = svf_linear[:, 1].T + data["SVF_acc"] = svf_linear[:, 2].T + data["SVF_jerk"] = svf_linear[:, 3].T + + svf_rotational = state_variable_filter(data["rpy"].T, data["time"], f_c=8, N_deriv=3) + data["SVF_rpy"] = svf_rotational[:, 0].T + data["SVF_drpy"] = svf_rotational[:, 1].T + data["SVF_ddrpy"] = svf_rotational[:, 2].T + data["SVF_dddrpy"] = svf_rotational[:, 3].T + rot = R.from_euler("xyz", data["SVF_rpy"]) + data["SVF_quat"] = rot.as_quat() + data["SVF_z_axis"] = rot.inv().as_matrix()[..., -1, :] + data["SVF_ang_vel"] = rpy_rates2ang_vel(data["SVF_quat"], data["SVF_drpy"]) + data["SVF_ang_acc"] = rpy_rates2ang_vel(data["SVF_quat"], data["SVF_ddrpy"]) + data["SVF_ang_jerk"] = rpy_rates2ang_vel(data["SVF_quat"], data["SVF_dddrpy"]) + + svf_input_f = state_variable_filter(data["cmd_f"], data["time"], f_c=6, N_deriv=3) + data["SVF_cmd_f"] = svf_input_f[0] + svf_input_rpy = state_variable_filter(data["cmd_rpy"].T, data["time"], f_c=8, N_deriv=3) + data["SVF_cmd_rpy"] = svf_input_rpy[:, 0].T + + R_act = rot.as_matrix() + rot_cmd = R.from_euler("xyz", data["SVF_cmd_rpy"]) + R_des = rot_cmd.as_matrix() + eRM = np.matmul(np.swapaxes(R_des, -1, -2), R_act) - np.matmul( + np.swapaxes(R_act, -1, -2), R_des + ) + data["SVF_eR"] = np.stack( + (eRM[..., 2, 1], eRM[..., 0, 2], eRM[..., 1, 0]), axis=-1 + ) # vee operator (SO3 to R3) + data["SVF_eR_vec"] = (rot.inv() * rot_cmd).as_rotvec() + + return data + + +def state_variable_filter(y: Array, t: Array, f_c: float = 1, N_deriv: int = 2) -> Array: + """A state variable filter that low pass filters the signal and computes the derivatives. + + Args: + y: The signal to be filtered. Can be 1D (signal_length) or 2D (batch_size, signal_length). + t: The time values for the signal. Optimally fixed sampling frequency. + f_c: Corner frequency of the filter in Hz. Defaults to 1. + N_deriv: Number of derivatives to be computed. Defaults to 2. + + Returns: + Array: The filtered signal and its derivatives. Shape (batch_size, N_deriv+1, signal_length) + """ + if y.ndim == 1: + y = y[None, :] # Add batch dimension if single signal + batch_size, signal_length = y.shape + + # The filter needs to have a minimum of two extra states + # One for the filtered input signal and one for the actual filter + N_ord = N_deriv + 2 + omega_c = 2 * np.pi * f_c + f_s = 1 / np.mean(np.diff(t)) + + b, a = butter(N=N_ord, Wn=omega_c, analog=True) + b_dig, a_dig = bilinear(b, a, fs=f_s) + a_flipped = np.flip(a) + + def _f(t: Array, x: Array, u: Array) -> Array: + x_dot = [] + x_dot_last = 0 + # The first states are a simple integrator chain + for i in np.arange(1, N_ord): + x_dot.append(x[i]) + # Last state uses the filter coefficients + for i in np.arange(0, N_ord): + x_dot_last -= a_flipped[i] * x[i] + x_dot_last += b[0] * u(t) + x_dot.append(x_dot_last) + + return x_dot + + results = np.zeros((batch_size, N_deriv + 1, signal_length)) + + for i in range(batch_size): + # Define input + # Prefilter input backwards to remove time shift + # Add padding to remove filter oscillations in data + pad = 100 + y_backwards = np.flip(y[i], axis=-1) + y_backwards_padded = np.concatenate([np.ones(pad) * y_backwards[0], y_backwards]) + zi = lfiltic( + b_dig, a_dig, y_backwards_padded, x=y_backwards_padded + ) # initial filter conditions + y_backwards, _ = lfilter(b_dig, a_dig, y_backwards_padded, axis=-1, zi=zi) + u = interp1d( + t, np.flip(y_backwards[pad:], axis=-1), kind="linear", fill_value="extrapolate" + ) + + # Solve system with initial conditions + x0 = np.zeros(N_ord) + x0[0] = y[i, 0] + sol = solve_ivp(_f, [t[0], t[-1]], x0, t_eval=t, args=(u,)) + + results[i] = sol.y[:-1] # Last state is not of interest + + return results.squeeze() # Remove batch dim if not needed diff --git a/crazyflow/dynamics/utils/identification.py b/crazyflow/dynamics/utils/identification.py new file mode 100644 index 0000000..47ea99c --- /dev/null +++ b/crazyflow/dynamics/utils/identification.py @@ -0,0 +1,502 @@ +"""This module contains functions to identify so_rpy dynamics from data.""" + +from __future__ import annotations + +import logging +from functools import partial +from typing import TYPE_CHECKING, Callable, Literal + +import jax # noqa: I001 +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +from jax.scipy.spatial.transform import Rotation as R # noqa: F401 +from scipy.optimize import least_squares + +from crazyflow.dynamics.so_rpy_rotor_drag import dynamics as dynamics_so_rpy_rotor_drag +from crazyflow.dynamics.utils.rotation import ( # noqa: F401 + ang_vel_deriv2rpy_rates_deriv, + rpy_rates2ang_vel, +) + +if TYPE_CHECKING: + from crazyflow._typing import Array # To be changed to array_api_typing later + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +# Remove unused parameters by partial application to make code cleaner +dynamics_translation = partial( + dynamics_so_rpy_rotor_drag, + J=jnp.zeros((3, 3)), + J_inv=jnp.zeros((3, 3)), + rpy_coef=jnp.zeros(3), + rpy_rates_coef=jnp.zeros(3), + cmd_rpy_coef=jnp.zeros(3), +) + +dynamics_rotation = partial( + dynamics_so_rpy_rotor_drag, + mass=0.1, + gravity_vec=jnp.array([0, 0, -9.81]), + thrust_time_coef=0.1, + acc_coef=0.0, + drag_matrix=jnp.zeros((3, 3)), + cmd_f_coef=1.0, + J=jnp.zeros((3, 3)), + J_inv=jnp.zeros((3, 3)), +) + + +def _rmse(x1: Array, x2: Array) -> float: + return np.sqrt(np.mean((x1 - x2) ** 2)) + + +def _r2(x1: Array, x2: Array) -> float: + return 1 - np.sum((x1 - x2) ** 2) / np.sum((x1 - np.mean(x1)) ** 2) + + +# region translation +def _simulate_system_translation( + quat: Array, vel: Array, cmd_f: Array, t: Array, params: Array, constants: dict[str, Array] +) -> Array: + """Simulate the dynamical system and return the derivatives. + + Args: + quat: Orientation (quaternion) of the drone (N, 4) + vel: Velocity of the drone (N, 3) + cmd_f: Commanded thrust (N,) + t: Time samples (N,) + params: Dynamics parameters [cmd_f_coef, thrust_time_coef, drag_xy_coef, drag_z_coef] + constants: Additional constants (mass, gravity_vec, etc.) + + returns: predicted acceleration (N, 3) + """ + N = vel.shape[0] + zeros_Nx3 = jnp.zeros((N, 3)) + zeros_Nx4 = jnp.zeros((N, 4)) + cmd = jnp.concatenate([zeros_Nx3, cmd_f[..., None]], axis=-1) + thrust0 = jnp.array([cmd_f[0]]) # Assuming hover at start, i.e., actual = command + dt = jnp.diff(t) + + # Rollout thrust dynamics + def _step_thrust(carry: Array, inputs: tuple) -> tuple: + dt_step, u = inputs + + _, _, _, _, rotor_vel_dot = dynamics_translation( + pos=zeros_Nx3, + quat=quat, + vel=vel, + ang_vel=zeros_Nx3, + cmd=u, + rotor_vel=carry, + mass=constants["mass"], + gravity_vec=constants["gravity_vec"], + thrust_time_coef=params[1], + acc_coef=0.0, + drag_matrix=jnp.diag(jnp.array([params[2], params[2], params[3]])), + cmd_f_coef=params[0], + ) + x_next = jnp.where(params[1] == 0.0, u[-1], carry + rotor_vel_dot * dt_step) + return x_next, x_next + + _, thrusts = jax.lax.scan(_step_thrust, thrust0, (dt, cmd[:-1])) + # prepend thrust0 to match length + thrusts = jnp.squeeze(thrusts) + thrusts = jnp.concat([thrust0, thrusts], axis=0) + + # Rollout linear dynamics (vectorized) + _, _, acc, _, _ = dynamics_translation( + pos=zeros_Nx3, + quat=quat, + vel=vel, + ang_vel=zeros_Nx3, + cmd=zeros_Nx4, + rotor_vel=thrusts[..., None], + mass=constants["mass"], + gravity_vec=constants["gravity_vec"], + thrust_time_coef=params[1], + acc_coef=0.0, + drag_matrix=jnp.diag(jnp.array([params[2], params[2], params[3]])), + cmd_f_coef=params[0], + ) + return acc + + +def _build_residuals_fun_translation( + dynamics: Literal["so_rpy", "so_rpy_rotor", "so_rpy_rotor_drag"], +) -> tuple[Callable, Callable]: + """Build residual function for the given dynamics type.""" + + def _residuals_trans( + params: Array, + quat: Array, + vel: Array, + cmd_f: Array, + t: Array, + constants: dict[str, Array], + acc_observed: Array, + ) -> Array: + acc = _simulate_system_translation(quat, vel, cmd_f, t, params, constants) + return jnp.linalg.norm(acc_observed - acc, axis=-1) + + # JAX analytic Jacobian + jac_fun = jax.jacfwd(_residuals_trans) # Jacobian w.r.t. first arg (params) + jac_fun = jax.jit(jac_fun) + + def _residual_fun_trans( + params: Array, + quat: Array, + vel: Array, + cmd_f: Array, + t: Array, + constants: dict[str, Array], + acc_observed: Array, + ) -> Callable: + match dynamics: # Dummy values for other params + case "so_rpy": + params_jnp = jnp.array([params[0], 0.0, 0.0, 0.0]) + case "so_rpy_rotor": + params_jnp = jnp.array([params[0], params[1], 0.0, 0.0]) + case "so_rpy_rotor_drag": + params_jnp = jnp.array([params[0], params[1], params[2], params[3]]) + case _: + raise ValueError(f"Unknown dynamics type: {dynamics}") + return jax.device_get( + _residuals_trans(params_jnp, quat, vel, cmd_f, t, constants, acc_observed) + ) + + def _residual_fun_trans_jac( + params: Array, + quat: Array, + vel: Array, + cmd_f: Array, + t: Array, + constants: dict[str, Array], + acc_observed: Array, + ) -> Callable: + match dynamics: # Dummy values for other params + case "so_rpy": + params_jnp = jnp.array([params[0], 0.0, 0.0, 0.0]) + case "so_rpy_rotor": + params_jnp = jnp.array([params[0], params[1], 0.0, 0.0]) + case "so_rpy_rotor_drag": + params_jnp = jnp.array([params[0], params[1], params[2], params[3]]) + case _: + raise ValueError(f"Unknown dynamics type: {dynamics}") + return jax.device_get(jac_fun(params_jnp, quat, vel, cmd_f, t, constants, acc_observed)) + + return _residual_fun_trans, _residual_fun_trans_jac + + +def sys_id_translation( + dynamics: Literal["so_rpy", "so_rpy_rotor", "so_rpy_rotor_drag"], + mass: float, + data: dict[str, Array], + data_validation: dict[str, Array] | None = None, + gravity: Array = np.array([0, 0, -9.81]), + verbose: int = 0, + plot: bool = False, +) -> dict[str, Array]: + """Identify the translational part of the so_rpy dynamics from data. + + Args: + dynamics: Dynamics type to identify. + mass: Mass of the drone. + data: Training data containing time, and the SVF values of vel, acc, quat, cmd_f. + data_validation: Optional validation data containing the same fields as data. + gravity: Gravity vector in world frame, i.e., [0, 0, -9.81]. + verbose: Verbosity level for the optimizer from 0 to 2. + plot: Whether to plot the results. + + Returns: Identified dynamics parameters. + """ + theta0 = [1.0, 1.0, 0.0, 0.0] + method = "trf" + xtol, ftol, gtol = 1e-10, 1e-10, 1e-10 + constants = {"mass": mass, "gravity_vec": gravity} + # Convert the data to jnp arrays for use with jax + t = jnp.array(data["time"]) + vel = jnp.array(data["SVF_vel"]) + acc = jnp.array(data["SVF_acc"]) + quat = jnp.array(data["SVF_quat"]) + cmd_f = jnp.array(data["SVF_cmd_f"]) + + # Identification + residual_fun_trans, residual_fun_trans_jac = _build_residuals_fun_translation(dynamics) + res = least_squares( + residual_fun_trans, + x0=theta0, + jac=residual_fun_trans_jac, + args=(quat, vel, cmd_f, t, constants, acc), + method=method, + xtol=xtol, + ftol=ftol, + gtol=gtol, + verbose=verbose, + ) + + theta = res.x + params = {"cmd_f_coef": theta[0]} + if "rotor" in dynamics: + params["thrust_time_coef"] = theta[1] + else: + theta[1] = 0.0 + if "drag" in dynamics: + params["drag_xy_coef"] = theta[2] + params["drag_z_coef"] = theta[3] + else: + theta[2] = 0.0 + theta[3] = 0.0 + + acc_pred = _simulate_system_translation(quat, vel, cmd_f, t, theta, constants) + if data_validation is not None: + t_valid = jnp.array(data_validation["time"]) + vel_valid = jnp.array(data_validation["SVF_vel"]) + acc_valid = jnp.array(data_validation["SVF_acc"]) + quat_valid = jnp.array(data_validation["SVF_quat"]) + cmd_f_valid = jnp.array(data_validation["SVF_cmd_f"]) + acc_pred_valid = _simulate_system_translation( + quat_valid, vel_valid, cmd_f_valid, t_valid, theta, constants + ) + + # Report + txt = f"\n=== Stats {dynamics} ===" + txt += f"\nParameters: {params=}" + txt += f"\nTraining success={res.success}, results:" + txt += f"\nRMSE={_rmse(acc, acc_pred):.6f}" + txt += f"\nR^2={_r2(acc, acc_pred):.4f}" + if data_validation is not None: + txt += "\nValidation results:" + txt += f"\nRMSE={_rmse(acc_valid, acc_pred_valid):.6f}" + txt += f"\nR^2={_r2(acc_valid, acc_pred_valid):.4f}" + logger.info(txt) + + # Plotting + if plot: + # Plot acceleration + fig, axs = plt.subplots(2, 1, figsize=(12, 5)) + + # Training data subplot + axs[0].plot(t, acc, label="Measured acc") + axs[0].plot(t, acc_pred, "--", label="Predicted acc") + axs[0].set_xlabel("Time [s]") + axs[0].set_ylabel("Output") + + # Validation data subplot + if data_validation is not None: + axs[1].plot(t_valid, acc_valid, label="Measured acc (valid)") + axs[1].plot(t_valid, acc_pred_valid, "--", label="Predicted acc (valid)") + axs[1].set_xlabel("Time [s]") + axs[1].set_ylabel("Output") + + for ax in axs.flat: + ax.grid(True) + ax.legend() + + plt.tight_layout() + plt.show() + + # Plot commanded thrust vs actual thrust + fig, ax = plt.subplots(1, 1, figsize=(6, 6)) + + ax.scatter( + cmd_f, np.linalg.norm((acc - constants["gravity_vec"]) * constants["mass"], axis=-1) + ) + cmd_thrust_lin = np.linspace(np.min(cmd_f) * 0.9, np.max(cmd_f) * 1.1, 1000) + ax.plot(cmd_thrust_lin, theta[0] * cmd_thrust_lin, label="Fit") + ax.set_xlabel("Commanded Thrust [N]") + ax.set_ylabel("Actual Thrust [N]") + ax.set_xlim(0.1, 0.8) + ax.set_ylim(0.1, 0.8) + + plt.tight_layout() + plt.show() + + return params + + +# region rotation +def _simulate_system_rotation(cmd_rpy: Array, t: Array, params: Array) -> Array: + """Simulate the 2nd-order system and return the trajectory. + + Args: + cmd_rpy: Commanded orientation (N, 3) + t: Time samples (N,) + params: Dynamics parameters [rpy_coef, rpy_rates_coef, cmd_rpy_coef] + + returns: predicted acceleration (N, 3) + """ + cmd = jnp.concatenate([cmd_rpy, jnp.zeros((cmd_rpy.shape[0], 1))], axis=-1) + dt = jnp.diff(t) + x0 = jnp.zeros((2, 3)) # rpy, rpy_rates + + def _step_so_system(carry: Array, inputs: Array) -> tuple: + dt_step, cmd = inputs + rpy_coef = jnp.array([params[0], params[0], params[1]]) + rpy_rates_coef = jnp.array([params[2], params[2], params[3]]) + cmd_rpy_coef = jnp.array([params[4], params[4], params[5]]) + rpy, rpy_rates = carry[0], carry[1] + + ### Alternative 1: Using the actual dynamics (slower) + quat = R.from_euler("xyz", rpy).as_quat() + ang_vel = rpy_rates2ang_vel(quat, rpy_rates) + _, _, _, ang_acc, _ = dynamics_rotation( + pos=jnp.array([0.0, 0.0, 0.0]), + quat=quat, + vel=jnp.array([0.0, 0.0, 0.0]), + ang_vel=ang_vel, + cmd=cmd, + rotor_vel=jnp.array([0.0]), + rpy_coef=rpy_coef, + rpy_rates_coef=rpy_rates_coef, + cmd_rpy_coef=cmd_rpy_coef, + ) + drpy_rates = ang_vel_deriv2rpy_rates_deriv(quat, ang_vel, ang_acc) + ### Alternative 2: Using the 2nd-order part directly (faster) + # drpy_rates = rpy_coef * rpy + rpy_rates_coef * rpy_rates + cmd_rpy_coef * cmd[:-1] + + ### Integration + next_rpy = rpy + rpy_rates * dt_step + next_rpy_rates = rpy_rates + drpy_rates * dt_step + x_next = jnp.stack([next_rpy, next_rpy_rates], axis=0) + return x_next, x_next + + _, xs = jax.lax.scan(_step_so_system, x0, (dt, cmd[:-1])) + # prepend x0 to match length + xs = jnp.vstack([jnp.array(x0)[None, :], xs]) + rpy_hat = xs[:, 0] # output y = x1 + return rpy_hat + + +def _build_residuals_fun_rotation() -> tuple[Callable, Callable]: + """Build residual function for the given dynamics type.""" + + def _residuals_rot(params: Array, cmd_rpy: Array, t: Array, rpy_observed: Array) -> Array: + rpy = _simulate_system_rotation(cmd_rpy, t, params) + # return jnp.linalg.norm(rpy_observed - rpy, axis=-1) + return jnp.reshape(rpy_observed - rpy, (-1,)) + + # JAX analytic Jacobian + jac_fun = jax.jacfwd(_residuals_rot) # Jacobian w.r.t. first arg (params) + jac_fun = jax.jit(jac_fun) + + def _residual_fun_rot(params: Array, cmd_rpy: Array, t: Array, rpy_observed: Array) -> Callable: + residuals = jax.jit(_residuals_rot) + return jax.device_get(residuals(params, cmd_rpy, t, rpy_observed)) + + def _residual_fun_rot_jac( + params: Array, cmd_rpy: Array, t: Array, rpy_observed: Array + ) -> Callable: + return jax.device_get(jac_fun(params, cmd_rpy, t, rpy_observed)) + + return _residual_fun_rot, _residual_fun_rot_jac + + +def sys_id_rotation( + data: dict[str, Array], + data_validation: dict[str, Array] | None = None, + verbose: int = 0, + plot: bool = False, +) -> dict[str, Array]: + """Identify the rotational part of the so_rpy dynamics from data. + + Args: + data: Training data containing time, and the SVF values of rpy [rad], cmd_rpy [rad]. + data_validation: Optional validation data containing the same fields as data. + verbose: Verbosity level for the optimizer from 0 to 2. + plot: Whether to plot the results. + + Returns: Identified dynamics parameters. + """ + # theta includes the values for roll/pitch (same value) and yaw + theta0 = np.array([-10.0, -10.0, -1.0, -1.0, 10.0, 10.0]) # ry, ry_rates, cmd_ry + method = "trf" + xtol, ftol, gtol = 1e-10, 1e-10, 1e-10 + t = jnp.array(data["time"]) + rpy = jnp.array(data["SVF_rpy"]) + cmd_rpy = jnp.array(data["SVF_cmd_rpy"]) + if data_validation is not None: + t_valid = jnp.array(data_validation["time"]) + rpy_valid = jnp.array(data_validation["SVF_rpy"]) + cmd_rpy_valid = jnp.array(data_validation["SVF_cmd_rpy"]) + + # Identification + residual_fun_rot, residual_fun_rot_jac = _build_residuals_fun_rotation() + res = least_squares( + residual_fun_rot, + x0=theta0, + jac=residual_fun_rot_jac, + args=(cmd_rpy, t, rpy), + method=method, + xtol=xtol, + ftol=ftol, + gtol=gtol, + verbose=verbose, + ) + theta = res.x + + rpy_coef = np.array([theta[0], theta[0], theta[1]]) + rpy_rates_coef = np.array([theta[2], theta[2], theta[3]]) + cmd_rpy_coef = np.array([theta[4], theta[4], theta[5]]) + params = {"rpy_coef": rpy_coef, "rpy_rates_coef": rpy_rates_coef, "cmd_rpy_coef": cmd_rpy_coef} + + rpy_pred = _simulate_system_rotation(cmd_rpy, t, theta) + if data_validation is not None: + rpy_pred_valid = _simulate_system_rotation(cmd_rpy_valid, t_valid, theta) + + # Report + txt = "\n=== Stats roll & pitch ===" + txt += f"\nEstimated: {rpy_coef=}, {rpy_rates_coef=}, {cmd_rpy_coef=}" + txt += f"\nTraining success={res.success}, results:" + txt += f"\nRMSE={_rmse(rpy, rpy_pred):.6f}" + txt += f"\nR^2={_r2(rpy, rpy_pred):.4f}" + + if data_validation is not None: + txt += "\nValidation results:" + txt += f"\nRMSE roll={_rmse(rpy_valid[..., 0], rpy_pred_valid[..., 0]):.6f}" + txt += f"\nRMSE pitch={_rmse(rpy_valid[..., 1], rpy_pred_valid[..., 1]):.6f}" + txt += f"\nR^2 roll={_r2(rpy_valid[..., 0], rpy_pred_valid[..., 0]):.4f}" + txt += f"\nR^2 pitch={_r2(rpy_valid[..., 1], rpy_pred_valid[..., 1]):.4f}" + logger.info(txt) + + # Plotting + if plot: + fig, axs = plt.subplots(3, 2, figsize=(20, 12)) + plt.suptitle("RPY dynamics fit") + + axs[0, 0].plot(t, rpy[..., 0], label="Measured roll") + axs[0, 0].plot(t, rpy_pred[..., 0], "--", label="Predicted roll") + axs[0, 0].set_ylabel("Roll [rad]") + + axs[0, 1].plot(t_valid, rpy_valid[..., 0], label="Measured roll (valid)") + axs[0, 1].plot(t_valid, rpy_pred_valid[..., 0], "--", label="Predicted roll (valid)") + axs[0, 1].set_ylabel("Roll [rad]") + + axs[1, 0].plot(t, rpy[..., 1], label="Measured pitch") + axs[1, 0].plot(t, rpy_pred[..., 1], "--", label="Predicted pitch") + axs[1, 0].set_ylabel("Pitch [rad]") + + axs[1, 1].plot(t_valid, rpy_valid[..., 1], label="Measured pitch (valid)") + axs[1, 1].plot(t_valid, rpy_pred_valid[..., 1], "--", label="Predicted pitch (valid)") + axs[1, 1].set_ylabel("Pitch [rad]") + + axs[2, 0].plot(t, rpy[..., 2], label="Measured yaw") + axs[2, 0].plot(t, rpy_pred[..., 2], "--", label="Predicted yaw") + axs[2, 0].set_xlabel("Time [s]") + axs[2, 0].set_ylabel("Yaw [rad]") + + axs[2, 1].plot(t_valid, rpy_valid[..., 2], label="Measured yaw (valid)") + axs[2, 1].plot(t_valid, rpy_pred_valid[..., 2], "--", label="Predicted yaw (valid)") + axs[2, 1].set_xlabel("Time [s]") + axs[2, 1].set_ylabel("Yaw [rad]") + + for ax in axs.flat: + ax.grid(True) + ax.legend() + + plt.tight_layout() + plt.show() + + return params diff --git a/crazyflow/dynamics/utils/rotation.py b/crazyflow/dynamics/utils/rotation.py new file mode 100644 index 0000000..3df0fa2 --- /dev/null +++ b/crazyflow/dynamics/utils/rotation.py @@ -0,0 +1,562 @@ +"""Rotation utilities for handling quaternion and Euler angle derivative conversions.""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +import casadi as cs +from array_api_compat import array_namespace +from scipy.spatial.transform import Rotation as R + +if TYPE_CHECKING: + from crazyflow._typing import Array # To be changed to array_api_typing later + +# region Numeric + + +def ang_vel2quat_dot(quat: Array, ang_vel: Array) -> Array: + """Calculates the quaternion derivative based on an angular velocity.""" + xp = array_namespace(quat) + # Split angular velocity + x = ang_vel[..., 0:1] + y = ang_vel[..., 1:2] + z = ang_vel[..., 2:3] + # Skew-symmetric matrix + ang_vel_skew = xp.stack( + [ + xp.concat((xp.zeros_like(x), -z, y), axis=-1), + xp.concat((z, xp.zeros_like(x), -x), axis=-1), + xp.concat((-y, x, xp.zeros_like(x)), axis=-1), + ], + axis=-2, + ) + # First row of Xi + xi1 = xp.concat((xp.zeros_like(x), -ang_vel), axis=-1) + # Second to fourth rows of Xi + ang_vel_col = xp.expand_dims(ang_vel, axis=-1) # (..., 3, 1) + xi2 = xp.concat((ang_vel_col, -ang_vel_skew), axis=-1) # (..., 3, 4) + # Combine into Xi + xi1_exp = xp.expand_dims(xi1, axis=-2) # (..., 1, 4) + xi = xp.concat((xi1_exp, xi2), axis=-2) # (..., 4, 4) + # Quaternion derivative + quat_exp = xp.expand_dims(quat, axis=-1) # (..., 4, 1) + result = 0.5 * xp.matmul(xi, quat_exp) # (..., 4, 1) + return xp.squeeze(result, axis=-1) # (..., 4) + + +def ang_vel2rpy_rates(quat: Array, ang_vel: Array) -> Array: + """Convert angular velocity to rpy rates with batch support.""" + xp = array_namespace(quat) + rpy = R.from_quat(quat).as_euler("xyz") + phi, theta = rpy[..., 0], rpy[..., 1] + + sin_phi = xp.sin(phi) + cos_phi = xp.cos(phi) + cos_theta = xp.cos(theta) + tan_theta = xp.tan(theta) + inv_cos_theta = 1 / cos_theta + + one = xp.ones_like(phi) + zero = xp.zeros_like(phi) + + W = xp.stack( + [ + xp.stack([one, sin_phi * tan_theta, cos_phi * tan_theta], axis=-1), + xp.stack([zero, cos_phi, -sin_phi], axis=-1), + xp.stack([zero, sin_phi * inv_cos_theta, cos_phi * inv_cos_theta], axis=-1), + ], + axis=-2, + ) + + return xp.matmul(W, ang_vel[..., None])[..., 0] + + +def rpy_rates2ang_vel(quat: Array, rpy_rates: Array) -> Array: + """Convert rpy rates to angular velocity with batch support.""" + xp = quat.__array_namespace__() + rpy = R.from_quat(quat).as_euler("xyz") + phi, theta = rpy[..., 0], rpy[..., 1] + + sin_phi = xp.sin(phi) + cos_phi = xp.cos(phi) + cos_theta = xp.cos(theta) + tan_theta = xp.tan(theta) + + one = xp.ones_like(phi) + zero = xp.zeros_like(phi) + + W = xp.stack( + [ + xp.stack([one, zero, -cos_theta * tan_theta], axis=-1), + xp.stack([zero, cos_phi, sin_phi * cos_theta], axis=-1), + xp.stack([zero, -sin_phi, cos_phi * cos_theta], axis=-1), + ], + axis=-2, + ) + + return xp.matmul(W, rpy_rates[..., None])[..., 0] + + +def ang_vel_deriv2rpy_rates_deriv(quat: Array, ang_vel: Array, ang_vel_deriv: Array) -> Array: + r"""Convert rpy rates derivatives to angular velocity derivatives. + + \[ + \dot{\psi} = \mathbf{\dot{W}}\mathbf{\omega} + \mathbf{W} \dot{\mathbf{\omega}} + \] + """ + xp = quat.__array_namespace__() + rpy = R.from_quat(quat).as_euler("xyz") + phi, theta = rpy[..., 0], rpy[..., 1] + rpy_rates = ang_vel2rpy_rates(quat, ang_vel) + phi_dot, theta_dot = rpy_rates[..., 0], rpy_rates[..., 1] + + sin_phi = xp.sin(phi) + cos_phi = xp.cos(phi) + sin_theta = xp.sin(theta) + cos_theta = xp.cos(theta) + tan_theta = xp.tan(theta) + + zero = xp.zeros_like(phi) + + W_dot = xp.stack( + [ + xp.stack( + [ + zero, + cos_phi * phi_dot * tan_theta + sin_phi * theta_dot / cos_theta**2, + -sin_phi * phi_dot * tan_theta + cos_phi * theta_dot / cos_theta**2, + ], + axis=-1, + ), + xp.stack([zero, -sin_phi * phi_dot, -cos_phi * phi_dot], axis=-1), + xp.stack( + [ + zero, + cos_phi * phi_dot / cos_theta + sin_phi * theta_dot * sin_theta / cos_theta**2, + -sin_phi * phi_dot / cos_theta + cos_phi * sin_theta * theta_dot / cos_theta**2, + ], + axis=-1, + ), + ], + axis=-2, + ) + return xp.matmul(W_dot, ang_vel[..., None])[..., 0] + ang_vel2rpy_rates(quat, ang_vel_deriv) + + +def rpy_rates_deriv2ang_vel_deriv(quat: Array, rpy_rates: Array, rpy_rates_deriv: Array) -> Array: + r"""Convert rpy rates derivatives to angular velocity derivatives. + + \[ + \dot{\omega} = \mathbf{\dot{W}}\dot{\mathbf{\psi}} + \mathbf{W} \ddot{\mathbf{\psi}} + \] + """ + xp = quat.__array_namespace__() + rpy = R.from_quat(quat).as_euler("xyz") + phi, theta = rpy[..., 0], rpy[..., 1] + phi_dot, theta_dot = rpy_rates[..., 0], rpy_rates[..., 1] + + sin_phi = xp.sin(phi) + cos_phi = xp.cos(phi) + sin_theta = xp.sin(theta) + cos_theta = xp.cos(theta) + + zero = xp.zeros_like(phi) + + W_dot = xp.stack( + [ + xp.stack([zero, zero, -cos_theta * theta_dot], axis=-1), + xp.stack( + [ + zero, + -sin_phi * phi_dot, + cos_phi * phi_dot * cos_theta - sin_phi * sin_theta * theta_dot, + ], + axis=-1, + ), + xp.stack( + [ + zero, + -cos_phi * phi_dot, + -sin_phi * phi_dot * cos_theta - cos_phi * sin_theta * theta_dot, + ], + axis=-1, + ), + ], + axis=-2, + ) + return xp.matmul(W_dot, rpy_rates[..., None])[..., 0] + rpy_rates2ang_vel(quat, rpy_rates_deriv) + + +# region Symbolic + + +def cs_quat2euler(quat: cs.MX, seq: str = "xyz", degrees: bool = False) -> cs.MX: + """Convert a CasADi symbolic quaternion to Euler angles. + + Symbolic equivalent of ``scipy.spatial.transform.Rotation.from_quat(q).as_euler(seq)``, + implemented in CasADi ``MX`` so it can be differentiated and compiled by + CasADi-based solvers. + + Args: + quat: CasADi ``MX`` column vector of length 4, in scalar-last (xyzw) + convention. + seq: Three-character axis sequence string. Lowercase letters (e.g. + ``"xyz"``) denote extrinsic rotations; uppercase (e.g. ``"XYZ"``) + denote intrinsic rotations. Consecutive axes must differ. + degrees: If ``True``, the returned angles are in degrees. Defaults to + ``False`` (radians). + + Returns: + CasADi ``MX`` column vector of length 3 containing the Euler angles in + the requested sequence and unit. + """ + if len(seq) != 3: + raise ValueError(f"Expected 3 axes, got {len(seq)}.") + + intrinsic = re.match(r"^[XYZ]{1,3}$", seq) is not None + extrinsic = re.match(r"^[xyz]{1,3}$", seq) is not None + + if not (intrinsic or extrinsic): + raise ValueError( + "Expected axes from `seq` to be from ['x', 'y', 'z'] or ['X', 'Y', 'Z'], got {}".format( + seq + ) + ) + + if any(seq[i] == seq[i + 1] for i in range(2)): + raise ValueError("Expected consecutive axes to be different, got {}".format(seq)) + + seq = seq.lower() + + # Compute euler from quat + if extrinsic: + angle_first = 0 + angle_third = 2 + else: + seq = seq[::-1] + angle_first = 2 + angle_third = 0 + + def elementary_basis_index(axis: str) -> int: + """Return the 0-based index (0=x, 1=y, 2=z) for an axis character.""" + if axis == "x": + return 0 + elif axis == "y": + return 1 + else: + return 2 + + i = elementary_basis_index(seq[0]) + j = elementary_basis_index(seq[1]) + k = elementary_basis_index(seq[2]) + + symmetric = i == k + + if symmetric: + k = 3 - i - j # get third axis + + # Check if permutation is even (+1) or odd (-1) + sign = (i - j) * (j - k) * (k - i) // 2 + + eps = 1e-7 + + if symmetric: + a = quat[3] + b = quat[i] + c = quat[j] + d = quat[k] * sign + else: + a = quat[3] - quat[j] + b = quat[i] + quat[k] * sign + c = quat[j] + quat[3] + d = quat[k] * sign - quat[i] + + angles1 = 2.0 * cs.arctan2(cs.sqrt(c**2 + d**2), cs.sqrt(a**2 + b**2)) + + case = cs.if_else( + cs.fabs(angles1) <= eps, 1, cs.if_else(cs.fabs(angles1 - cs.np.pi) <= eps, 2, 0) + ) + + half_sum = cs.arctan2(b, a) + half_diff = cs.arctan2(d, c) + + angles_case_0_ = [None, angles1, None] + angles_case_0_[angle_first] = half_sum - half_diff + angles_case_0_[angle_third] = half_sum + half_diff + angles_case_0 = cs.vertcat(*angles_case_0_) + + angles_case_else_ = [None, angles1, 0.0] + angles_case_else_[0] = cs.if_else( + case == 1, 2.0 * half_sum, 2.0 * half_diff * (-1.0 if extrinsic else 1.0) + ) + angles_case_else = cs.vertcat(*angles_case_else_) + + angles = cs.if_else(case == 0, angles_case_0, angles_case_else) + + if not symmetric: + angles[angle_third] *= sign + angles[1] -= cs.np.pi * 0.5 + + for i in range(3): + angles[i] += cs.if_else( + angles[i] < -cs.np.pi, + 2.0 * cs.np.pi, + cs.if_else(angles[i] > cs.np.pi, -2.0 * cs.np.pi, 0.0), + ) + + if degrees: + angles = (cs.np.pi / 180.0) * cs.horzcat(angles) + + return angles + + +def cs_quat2matrix(quat: cs.MX) -> cs.MX: + """Creates a symbolic rotation matrix based on a symbolic quaternion. + + From + """ + x = quat[0] / cs.norm_2(quat) + y = quat[1] / cs.norm_2(quat) + z = quat[2] / cs.norm_2(quat) + w = quat[3] / cs.norm_2(quat) + + x2 = x * x + y2 = y * y + z2 = z * z + w2 = w * w + + xy = x * y + zw = z * w + xz = x * z + yw = y * w + yz = y * z + xw = x * w + + matrix = cs.horzcat( + cs.vertcat(x2 - y2 - z2 + w2, 2.0 * (xy + zw), 2.0 * (xz - yw)), + cs.vertcat(2.0 * (xy - zw), -x2 + y2 - z2 + w2, 2.0 * (yz + xw)), + cs.vertcat(2.0 * (xz + yw), 2.0 * (yz - xw), -x2 - y2 + z2 + w2), + ) + + return matrix + + +def cs_rpy2matrix(rpy: cs.MX, degrees: bool = False) -> cs.MX: + """Creates a symbolic rotation matrix from roll, pitch, yaw (XYZ convention). + + Should be equivalent to scipy.spatial.transform.Rotation.from_euler('xyz', rpy).as_matrix(). + """ + roll, pitch, yaw = rpy[0], rpy[1], rpy[2] + if degrees: + roll *= cs.pi / 180 + pitch *= cs.pi / 180 + yaw *= cs.pi / 180 + + cr = cs.cos(roll) + sr = cs.sin(roll) + cp = cs.cos(pitch) + sp = cs.sin(pitch) + cy = cs.cos(yaw) + sy = cs.sin(yaw) + + # Rotation matrix for R = Rz(yaw) * Ry(pitch) * Rx(roll) + matrix = cs.vertcat( + cs.horzcat(cy * cp, cy * sp * sr - sy * cr, cy * sp * cr + sy * sr), + cs.horzcat(sy * cp, sy * sp * sr + cy * cr, sy * sp * cr - cy * sr), + cs.horzcat(-sp, cp * sr, cp * cr), + ) + + return matrix + + +# region Wrappers + + +def create_cs_ang_vel2rpy_rates() -> cs.Function: + """Build a compiled CasADi function that converts angular velocity to RPY rates. + + Returns: + A ``casadi.Function`` with signature + ``(quat[4], ang_vel[3]) -> rpy_rates[3]`` + that evaluates the kinematic mapping + ``ṙpy = W(rpy) · ω`` for a given attitude quaternion and body-frame + angular velocity. + """ + qw = cs.MX.sym("qw") + qx = cs.MX.sym("qx") + qy = cs.MX.sym("qy") + qz = cs.MX.sym("qz") + quat = cs.vertcat(qx, qy, qz, qw) # Quaternions + rpy = cs_quat2euler(quat) + phi, theta = rpy[0], rpy[1] + p = cs.MX.sym("p") + q = cs.MX.sym("q") + r = cs.MX.sym("r") + ang_vel = cs.vertcat(p, q, r) # Angular velocity + + row1 = cs.horzcat(1, cs.sin(phi) * cs.tan(theta), cs.cos(phi) * cs.tan(theta)) + row2 = cs.horzcat(0, cs.cos(phi), -cs.sin(phi)) + row3 = cs.horzcat(0, cs.sin(phi) / cs.cos(theta), cs.cos(phi) / cs.cos(theta)) + + W = cs.vertcat(row1, row2, row3) + rpy_rates = W @ ang_vel + + return cs.Function("cs_ang_vel2rpy_rates", [quat, ang_vel], [rpy_rates]) + + +cs_ang_vel2rpy_rates = create_cs_ang_vel2rpy_rates() + + +def create_cs_rpy_rates2ang_vel() -> cs.Function: + """Build a compiled CasADi function that converts RPY rates to angular velocity. + + Returns: + A ``casadi.Function`` with signature + ``(quat[4], rpy_rates[3]) -> ang_vel[3]`` + that evaluates the inverse kinematic mapping + ``ω = W⁻¹(rpy) · ṙpy``. + """ + qw = cs.MX.sym("qw") + qx = cs.MX.sym("qx") + qy = cs.MX.sym("qy") + qz = cs.MX.sym("qz") + quat = cs.vertcat(qx, qy, qz, qw) # Quaternions + rpy = cs_quat2euler(quat) + phi, theta = rpy[0], rpy[1] + phi_dot = cs.MX.sym("phi_dot") + theta_dot = cs.MX.sym("theta_dot") + psi_dot = cs.MX.sym("psi_dot") + rpy_rates = cs.vertcat(phi_dot, theta_dot, psi_dot) # Euler rates + + row1 = cs.horzcat(1, 0, -cs.cos(theta) * cs.tan(theta)) + row2 = cs.horzcat(0, cs.cos(phi), cs.sin(phi) * cs.cos(theta)) + row3 = cs.horzcat(0, -cs.sin(phi), cs.cos(phi) * cs.cos(theta)) + + W = cs.vertcat(row1, row2, row3) + ang_vel = W @ rpy_rates + return cs.Function("cs_rpy_rates2ang_vel", [quat, rpy_rates], [ang_vel]) + + +cs_rpy_rates2ang_vel = create_cs_rpy_rates2ang_vel() + + +def create_cs_ang_vel_deriv2rpy_rates_deriv() -> cs.Function: + """Build a compiled CasADi function that converts angular acceleration to RPY-rate derivatives. + + Returns: + A ``casadi.Function`` with signature + ``(quat[4], ang_vel[3], ang_vel_deriv[3]) -> rpy_rates_deriv[3]`` + implementing ``r̈py = Ẇ · ω + W · ω̇``. + """ + qw = cs.MX.sym("qw") + qx = cs.MX.sym("qx") + qy = cs.MX.sym("qy") + qz = cs.MX.sym("qz") + quat = cs.vertcat(qx, qy, qz, qw) # Quaternions + rpy = cs_quat2euler(quat) + phi, theta = rpy[0], rpy[1] + p = cs.MX.sym("p") + q = cs.MX.sym("q") + r = cs.MX.sym("r") + ang_vel = cs.vertcat(p, q, r) # Angular velocity + p_dot = cs.MX.sym("p_dot") + q_dot = cs.MX.sym("q_dot") + r_dot = cs.MX.sym("r_dot") + ang_vel_deriv = cs.vertcat(p_dot, q_dot, r_dot) # Angular acceleration + rpy_rates = cs_ang_vel2rpy_rates(quat, ang_vel) + phi_dot, theta_dot = rpy_rates[0], rpy_rates[1] + + row1 = cs.horzcat( + 0, + cs.cos(phi) * phi_dot * cs.tan(theta) + cs.sin(phi) * theta_dot / cs.cos(theta) ** 2, + -cs.sin(phi) * phi_dot * cs.tan(theta) + cs.cos(phi) * theta_dot / cs.cos(theta) ** 2, + ) + row2 = cs.horzcat(0, -cs.sin(phi) * phi_dot, -cs.cos(phi) * phi_dot) + row3 = cs.horzcat( + 0, + cs.cos(phi) * phi_dot / cs.cos(theta) + + cs.sin(phi) * theta_dot * cs.sin(theta) / cs.cos(theta) ** 2, + -cs.sin(phi) * phi_dot / cs.cos(theta) + + cs.cos(phi) * cs.sin(theta) * theta_dot / cs.cos(theta) ** 2, + ) + + W_dot = cs.vertcat(row1, row2, row3) + rpy_rates_deriv = W_dot @ ang_vel + cs_ang_vel2rpy_rates(quat, ang_vel_deriv) + + return cs.Function("cs_ang_vel2rpy_rates", [quat, ang_vel, ang_vel_deriv], [rpy_rates_deriv]) + + +cs_ang_vel_deriv2rpy_rates_deriv = create_cs_ang_vel_deriv2rpy_rates_deriv() + + +def create_cs_rpy_rates_deriv2ang_vel_deriv() -> cs.Function: + """Build a compiled CasADi function that converts RPY-rate derivatives to angular acceleration. + + Returns: + A ``casadi.Function`` with signature + ``(quat[4], rpy_rates[3], rpy_rates_deriv[3]) -> ang_vel_deriv[3]`` + implementing ``ω̇ = Ẇ · ṙpy + W · r̈py``. + """ + qw = cs.MX.sym("qw") + qx = cs.MX.sym("qx") + qy = cs.MX.sym("qy") + qz = cs.MX.sym("qz") + quat = cs.vertcat(qx, qy, qz, qw) # Quaternions + rpy = cs_quat2euler(quat) + phi, theta = rpy[0], rpy[1] + phi_dot = cs.MX.sym("phi_dot") + theta_dot = cs.MX.sym("theta_dot") + psi_dot = cs.MX.sym("psi_dot") + rpy_rates = cs.vertcat(phi_dot, theta_dot, psi_dot) # Euler rates + phi_dot_dot = cs.MX.sym("phi_dot_dot") + theta_dot_dot = cs.MX.sym("theta_dot_dot") + psi_dot_dot = cs.MX.sym("psi_dot_dot") + rpy_rates_deriv = cs.vertcat(phi_dot_dot, theta_dot_dot, psi_dot_dot) # Euler rates derivative + + row1 = cs.horzcat(0, 0, -cs.cos(theta) * theta_dot) + row2 = cs.horzcat( + 0, + -cs.sin(phi) * phi_dot, + cs.cos(phi) * phi_dot * cs.cos(theta) - cs.sin(phi) * cs.sin(theta) * theta_dot, + ) + row3 = cs.horzcat( + 0, + -cs.cos(phi) * phi_dot, + -cs.sin(phi) * phi_dot * cs.cos(theta) - cs.cos(phi) * cs.sin(theta) * theta_dot, + ) + + W_dot = cs.vertcat(row1, row2, row3) + ang_vel_deriv = W_dot @ rpy_rates + cs_rpy_rates2ang_vel(quat, rpy_rates_deriv) + + return cs.Function("cs_ang_vel2rpy_rates", [quat, rpy_rates, rpy_rates_deriv], [ang_vel_deriv]) + + +cs_rpy_rates_deriv2ang_vel_deriv = create_cs_rpy_rates_deriv2ang_vel_deriv() + + +def create_cs_quat2matrix() -> cs.Function: + """Generates a casadi numeric function from the cs_quat2matrix function.""" + qw = cs.MX.sym("qw") + qx = cs.MX.sym("qx") + qy = cs.MX.sym("qy") + qz = cs.MX.sym("qz") + quat = cs.vertcat(qx, qy, qz, qw) + matrix = cs_quat2matrix(quat) + return cs.Function("cs_quat2matrix", [quat], [matrix]) + + +cs_quat2matrix_func = create_cs_quat2matrix() + + +def create_cs_rpy2matrix() -> cs.Function: + """Generates a casadi numeric function from the cs_rpy2matrix function.""" + roll = cs.MX.sym("roll") + pitch = cs.MX.sym("pitch") + yaw = cs.MX.sym("yaw") + rpy = cs.vertcat(roll, pitch, yaw) + matrix = cs_rpy2matrix(rpy) + return cs.Function("cs_rpy2matrix", [rpy], [matrix]) + + +cs_rpy2matrix_func = create_cs_rpy2matrix() diff --git a/crazyflow/envs/drone_env.py b/crazyflow/envs/drone_env.py index 7efebbc..c5ed6ab 100644 --- a/crazyflow/envs/drone_env.py +++ b/crazyflow/envs/drone_env.py @@ -1,39 +1,39 @@ import warnings from functools import partial -from typing import Callable, Literal +from typing import Callable import jax import jax.numpy as jnp import numpy as np -from drone_controllers.core import load_params -from drone_controllers.mellinger import force_torque2rotor_vel from gymnasium import spaces from gymnasium.vector import AutoresetMode, VectorEnv from gymnasium.vector.utils import batch_space from jax import Array from numpy.typing import NDArray -from crazyflow.control.control import Control +from crazyflow.control import Control +from crazyflow.control.core import load_params +from crazyflow.control.mellinger import force_torque2rotor_vel +from crazyflow.dynamics import Dynamics from crazyflow.sim import Sim from crazyflow.sim.data import SimData -from crazyflow.sim.physics import Physics from crazyflow.sim.pipeline import append_fn from crazyflow.utils import leaf_replace -def action_space(control_type: Control, drone_model: str) -> spaces.Box: +def action_space(control_type: Control, drone: str) -> spaces.Box: """Select the appropriate action space for a given control type. Args: control_type: The desired control mode. - drone_model: Drone model of the environment. + drone: Drone of the environment. Returns: The action space. """ match control_type: case Control.attitude: - params = load_params(force_torque2rotor_vel, drone_model) + params = load_params(force_torque2rotor_vel, drone) thrust_min, thrust_max = params["thrust_min"] * 4, params["thrust_max"] * 4 return spaces.Box( np.array([-np.pi / 2, -np.pi / 2, -np.pi / 2, thrust_min], dtype=np.float32), @@ -64,8 +64,8 @@ def __init__( *, num_envs: int = 1, max_episode_time: float = 10.0, - physics: Literal["so_rpy", "first_principles"] | Physics = Physics.so_rpy, - drone_model: str = "cf2x_L250", + dynamics: Dynamics = Dynamics.so_rpy, + drone: str = "cf2x_L250", freq: int = 500, device: str = "cpu", reset_randomization: Callable[[SimData, Array], SimData] | None = None, @@ -75,8 +75,8 @@ def __init__( Args: num_envs: The number of environments to run in parallel. max_episode_time: The time horizon after which episodes are truncated (s). - physics: The crazyflow physics simulation model. - drone_model: Drone model of the environment. + dynamics: The crazyflow dynamics. + drone: Drone of the environment. freq: The frequency at which the environment is run. device: The device of the environment and the simulation. reset_randomization: A function that randomizes the initial state of the simulation. If @@ -86,12 +86,10 @@ def __init__( self.device = jax.devices(device)[0] self.freq = freq self.max_episode_time = max_episode_time - assert Physics(physics) in Physics, f"Invalid physics type {physics}" + assert Dynamics(dynamics) in Dynamics, f"Invalid dynamics type {dynamics}" # Initialize the simulation - self.sim = Sim( - n_worlds=num_envs, n_drones=1, drone_model=drone_model, device=device, physics=physics - ) + self.sim = Sim(n_worlds=num_envs, n_drones=1, drone=drone, device=device, dynamics=dynamics) assert self.sim.freq >= self.sim.control_freq, "Sim freq must be higher than control freq" if not self.sim.freq % self.freq == 0: # We can handle other cases, but it's not recommended @@ -109,7 +107,7 @@ def __init__( self._marked_for_reset = jnp.zeros((self.sim.n_worlds), dtype=jnp.bool_, device=self.device) # Define action and observation spaces - self.single_action_space = action_space(self.sim.control, self.sim.drone_model) + self.single_action_space = action_space(self.sim.control, self.sim.drone) self.action_space = batch_space(self.single_action_space, self.sim.n_worlds) self.single_observation_space = spaces.Dict( { @@ -223,6 +221,5 @@ def _reset_randomization(data: SimData, _: SimData, mask: Array) -> SimData: pos = jax.random.uniform(key=pos_key, shape=shape, minval=pos_min, maxval=pos_max) # Sample initial velocity vel = jax.random.uniform(key=vel_key, shape=shape, minval=-1.0, maxval=1.0) - # Setting initial ryp_rate when using physics.sys_id will not have an impact, so we skip it data = data.replace(states=leaf_replace(data.states, mask, pos=pos, vel=vel)) return data diff --git a/crazyflow/envs/figure_8_env.py b/crazyflow/envs/figure_8_env.py index 1416bed..ba862c1 100644 --- a/crazyflow/envs/figure_8_env.py +++ b/crazyflow/envs/figure_8_env.py @@ -1,5 +1,3 @@ -from typing import Literal - import jax import jax.numpy as jnp import numpy as np @@ -7,9 +5,9 @@ from gymnasium.vector.utils import batch_space from jax import Array +from crazyflow.dynamics import Dynamics from crazyflow.envs.drone_env import DroneEnv from crazyflow.sim.data import SimData -from crazyflow.sim.physics import Physics from crazyflow.sim.visualize import draw_line, draw_points from crazyflow.utils import leaf_replace @@ -30,8 +28,8 @@ def __init__( *, num_envs: int = 1, max_episode_time: float = 10.0, - physics: Literal["so_rpy", "first_principles"] | Physics = Physics.so_rpy, - drone_model: str = "cf2x_L250", + dynamics: Dynamics = Dynamics.so_rpy, + drone: str = "cf2x_L250", freq: int = 500, device: str = "cpu", ): @@ -43,16 +41,16 @@ def __init__( trajectory_time: Total time for completing the figure-eight trajectory in seconds. num_envs: Number of environments to run in parallel. max_episode_time: Maximum episode time in seconds. - physics: Physics backend to use. - drone_model: Drone model of the environment. + dynamics: Dynamics backend to use. + drone: Drone of the environment. freq: Frequency of the simulation. device: Device to use for the simulation. """ super().__init__( num_envs=num_envs, max_episode_time=max_episode_time, - physics=physics, - drone_model=drone_model, + dynamics=dynamics, + drone=drone, freq=freq, device=device, ) diff --git a/crazyflow/envs/landing_env.py b/crazyflow/envs/landing_env.py index ad37b39..3822445 100644 --- a/crazyflow/envs/landing_env.py +++ b/crazyflow/envs/landing_env.py @@ -1,5 +1,3 @@ -from typing import Literal - import jax import jax.numpy as jnp import mujoco @@ -8,9 +6,9 @@ from gymnasium.vector.utils import batch_space from jax import Array +from crazyflow.dynamics import Dynamics from crazyflow.envs.drone_env import DroneEnv from crazyflow.sim.data import SimState -from crazyflow.sim.physics import Physics class LandingEnv(DroneEnv): @@ -20,14 +18,14 @@ def __init__( self, num_envs: int = 1, max_episode_time: float = 10.0, - physics: Literal["so_rpy", "first_principles"] | Physics = Physics.so_rpy, + dynamics: Dynamics = Dynamics.so_rpy, freq: int = 500, device: str = "cpu", ): super().__init__( num_envs=num_envs, max_episode_time=max_episode_time, - physics=physics, + dynamics=dynamics, freq=freq, device=device, ) diff --git a/crazyflow/envs/reach_pos_env.py b/crazyflow/envs/reach_pos_env.py index be27dc1..a080b77 100644 --- a/crazyflow/envs/reach_pos_env.py +++ b/crazyflow/envs/reach_pos_env.py @@ -1,5 +1,4 @@ from functools import partial -from typing import Literal import jax import jax.numpy as jnp @@ -9,9 +8,9 @@ from gymnasium.vector.utils import batch_space from jax import Array +from crazyflow.dynamics import Dynamics from crazyflow.envs.drone_env import DroneEnv from crazyflow.sim.data import SimData -from crazyflow.sim.physics import Physics from crazyflow.utils import leaf_replace @@ -26,7 +25,7 @@ def __init__( vel_max: float = 1.0, num_envs: int = 1, max_episode_time: float = 10.0, - physics: Literal["so_rpy", "first_principles"] | Physics = Physics.so_rpy, + dynamics: Dynamics = Dynamics.so_rpy, freq: int = 500, device: str = "cpu", ): @@ -38,7 +37,7 @@ def __init__( super().__init__( num_envs=num_envs, max_episode_time=max_episode_time, - physics=physics, + dynamics=dynamics, freq=freq, device=device, reset_randomization=reset_randomization, diff --git a/crazyflow/envs/reach_vel_env.py b/crazyflow/envs/reach_vel_env.py index 52cdec5..082e80a 100644 --- a/crazyflow/envs/reach_vel_env.py +++ b/crazyflow/envs/reach_vel_env.py @@ -1,5 +1,3 @@ -from typing import Literal - import jax import jax.numpy as jnp import numpy as np @@ -7,8 +5,8 @@ from gymnasium.vector.utils import batch_space from jax import Array +from crazyflow.dynamics import Dynamics from crazyflow.envs.drone_env import DroneEnv -from crazyflow.sim.physics import Physics class ReachVelEnv(DroneEnv): @@ -19,14 +17,14 @@ def __init__( *, num_envs: int = 1, max_episode_time: float = 10.0, - physics: Literal["so_rpy", "first_principles"] | Physics = Physics.so_rpy, + dynamics: Dynamics = Dynamics.so_rpy, freq: int = 500, device: str = "cpu", ): super().__init__( num_envs=num_envs, max_episode_time=max_episode_time, - physics=physics, + dynamics=dynamics, freq=freq, device=device, ) diff --git a/crazyflow/sim/__init__.py b/crazyflow/sim/__init__.py index e5ad0d2..6a15efa 100644 --- a/crazyflow/sim/__init__.py +++ b/crazyflow/sim/__init__.py @@ -1,5 +1,4 @@ -from crazyflow.sim.physics import Physics +from crazyflow.dynamics import Dynamics from crazyflow.sim.sim import Sim -from crazyflow.sim.symbolic import symbolic_from_sim -__all__ = ["Sim", "Physics", "symbolic_from_sim"] +__all__ = ["Sim", "Dynamics"] diff --git a/crazyflow/sim/data.py b/crazyflow/sim/data.py index f778840..6872853 100644 --- a/crazyflow/sim/data.py +++ b/crazyflow/sim/data.py @@ -13,13 +13,11 @@ MellingerForceTorqueData, MellingerStateData, ) -from crazyflow.sim.physics import ( - FirstPrinciplesData, - Physics, - SoRpyData, - SoRpyRotorData, - SoRpyRotorDragData, -) +from crazyflow.dynamics import Dynamics +from crazyflow.dynamics.first_principles import Params as FirstPrinciplesParams +from crazyflow.dynamics.so_rpy import Params as SoRpyParams +from crazyflow.dynamics.so_rpy_rotor import Params as SoRpyRotorParams +from crazyflow.dynamics.so_rpy_rotor_drag import Params as SoRpyRotorDragParams @dataclass @@ -31,7 +29,7 @@ class SimState: vel: Array # (N, M, 3) """Velocity of the drone's center of mass in the world frame.""" ang_vel: Array # (N, M, 3) - """Angular velocity of the drone's center of mass in the world frame.""" + """Angular velocity of the drone in the body frame.""" force: Array # (N, M, 3) # CoM force """Force applied to the drone's center of mass in the world frame.""" torque: Array # (N, M, 3) # CoM torque @@ -118,7 +116,7 @@ def create( n_worlds: int, n_drones: int, control: Control, - drone_model: str, + drone: str, state_freq: int | None, attitude_freq: int | None, force_torque_freq: int | None, @@ -128,14 +126,12 @@ def create( rotor_vel = jnp.zeros((n_worlds, n_drones, 4), device=device) match control: case Control.state: - state = MellingerStateData.create( - n_worlds, n_drones, state_freq, drone_model, device - ) + state = MellingerStateData.create(n_worlds, n_drones, state_freq, drone, device) attitude = MellingerAttitudeData.create( - n_worlds, n_drones, attitude_freq, drone_model, device + n_worlds, n_drones, attitude_freq, drone, device ) force_torque = MellingerForceTorqueData.create( - n_worlds, n_drones, force_torque_freq, drone_model, device + n_worlds, n_drones, force_torque_freq, drone, device ) return SimControls( mode=control, @@ -146,10 +142,10 @@ def create( ) case Control.attitude: attitude = attitude = MellingerAttitudeData.create( - n_worlds, n_drones, attitude_freq, drone_model, device + n_worlds, n_drones, attitude_freq, drone, device ) force_torque = MellingerForceTorqueData.create( - n_worlds, n_drones, force_torque_freq, drone_model, device + n_worlds, n_drones, force_torque_freq, drone, device ) return SimControls( mode=control, @@ -160,7 +156,7 @@ def create( ) case Control.force_torque: force_torque = MellingerForceTorqueData.create( - n_worlds, n_drones, force_torque_freq, drone_model, device + n_worlds, n_drones, force_torque_freq, drone, device ) return SimControls( mode=control, @@ -189,20 +185,20 @@ class SimParams(typing.Protocol): @staticmethod def create( - n_worlds: int, n_drones: int, physics: Physics, drone_model: str, device: Device + n_worlds: int, n_drones: int, dynamics: Dynamics, drone: str, device: Device ) -> SimParams: """Create a default set of parameters for the simulation.""" - match physics: - case Physics.first_principles: - return FirstPrinciplesData.create(n_worlds, n_drones, drone_model, device) - case Physics.so_rpy: - return SoRpyData.create(n_worlds, n_drones, drone_model, device) - case Physics.so_rpy_rotor: - return SoRpyRotorData.create(n_worlds, n_drones, drone_model, device) - case Physics.so_rpy_rotor_drag: - return SoRpyRotorDragData.create(n_worlds, n_drones, drone_model, device) + match dynamics: + case Dynamics.first_principles: + return FirstPrinciplesParams.create(n_worlds, n_drones, drone, device) + case Dynamics.so_rpy: + return SoRpyParams.create(n_worlds, n_drones, drone, device) + case Dynamics.so_rpy_rotor: + return SoRpyRotorParams.create(n_worlds, n_drones, drone, device) + case Dynamics.so_rpy_rotor_drag: + return SoRpyRotorDragParams.create(n_worlds, n_drones, drone, device) case _: - raise ValueError(f"Physics mode {physics} not implemented") + raise ValueError(f"Dynamics mode {dynamics} not implemented") @dataclass @@ -222,7 +218,7 @@ class SimCore: rng_key: Array # (N, 1) """Random number generator key for the simulation.""" mjx_synced: Array # (1,) - """Whether the simulation data is synchronized with the MuJoCo model.""" + """Whether the simulation data is synchronized with the MuJoCo mjx_data.""" @staticmethod def create( diff --git a/crazyflow/sim/functional.py b/crazyflow/sim/functional.py index 4142c19..354af0b 100644 --- a/crazyflow/sim/functional.py +++ b/crazyflow/sim/functional.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from crazyflow.control import Control -from crazyflow.control.control import controllable as _controllable +from crazyflow.control.core import controllable as _controllable from crazyflow.utils import to_device if TYPE_CHECKING: @@ -28,10 +28,10 @@ def state_control(data: SimData, controls: Array) -> SimData: def attitude_control(data: SimData, controls: Array) -> SimData: """Attitude control function. - We need to stage the attitude controls because the sys_id physics mode operates directly on + We need to stage the attitude controls because the so_rpy dynamics mode operates directly on the attitude controls. If we were to directly update the controls, this would effectively - bypass the control frequency and run the attitude controller at the physics update rate. By - staging the controls, we ensure that the physics module sees the old controls until the + bypass the control frequency and run the attitude controller at the dynamics update rate. By + staging the controls, we ensure that the dynamics module sees the old controls until the controller updates at its correct frequency. """ assert data.controls.mode == Control.attitude, f"control type {data.controls.mode} not enabled" diff --git a/crazyflow/sim/integration.py b/crazyflow/sim/integration.py index bef3660..32f20b2 100644 --- a/crazyflow/sim/integration.py +++ b/crazyflow/sim/integration.py @@ -138,7 +138,7 @@ def _integrate( dt: The time step to integrate over. Returns: - The next position, quaternion, velocity, and roll, pitch, and yaw rates of the drone. + The next position, quaternion, velocity, angular velocity, and rotor velocity of the drone. """ next_pos = pos + dpos * dt # Prevent NaN gradients by setting extremely small rotations to 0. This should not be necessary @@ -183,7 +183,7 @@ def _integrate_symplectic( dt: The time step to integrate over. Returns: - The next position, quaternion, velocity, and roll, pitch, and yaw rates of the drone. + The next position, quaternion, velocity, angular velocity, and rotor velocity of the drone. """ next_vel = vel + dvel * dt next_ang_vel = ang_vel + dang_vel * dt diff --git a/crazyflow/sim/physics.py b/crazyflow/sim/physics.py deleted file mode 100644 index 1f9e873..0000000 --- a/crazyflow/sim/physics.py +++ /dev/null @@ -1,322 +0,0 @@ -"""Physics models for the simulation.""" - -from __future__ import annotations - -from enum import Enum -from typing import TYPE_CHECKING - -import jax -import jax.numpy as jnp -from drone_models.core import load_params -from drone_models.first_principles import dynamics as first_principles_dynamics -from drone_models.so_rpy import dynamics as so_rpy_dynamics -from drone_models.so_rpy_rotor import dynamics as so_rpy_rotor_dynamics -from drone_models.so_rpy_rotor_drag import dynamics as so_rpy_rotor_drag_dynamics -from flax.struct import dataclass -from jax import Array - -if TYPE_CHECKING: - from jax import Device - - from crazyflow.sim.data import SimData - - -class Physics(str, Enum): - """Physics mode for the simulation.""" - - first_principles = "first_principles" - so_rpy = "so_rpy" - so_rpy_rotor = "so_rpy_rotor" - so_rpy_rotor_drag = "so_rpy_rotor_drag" - default = first_principles - - -@dataclass -class FirstPrinciplesData: - mass: Array # (N, M, 1) - """Mass of the drone.""" - L: Array # (N, M, 1) - """Arm length of the drone.""" - prop_inertia: Array # (N, M, 1) - """Inertia of the propeller.""" - gravity_vec: Array # (N, M, 3) - """Gravity vector of the drone.""" - J: Array # (N, M, 3, 3) - """Inertia matrix of the drone.""" - J_inv: Array # (N, M, 3, 3) - """Inverse of the inertia matrix of the drone.""" - rpm2thrust: Array # (N, M, 1) - """Force constant of the drone.""" - rpm2torque: Array # (N, M, 1) - """Torque constant of the drone.""" - mixing_matrix: Array # (N, M, 3, 4) - """Mixing matrix of the drone.""" - drag_matrix: Array # (N, M, 3, 3) - """Drag matrix of the drone.""" - rotor_dyn_coef: Array # (N, M, 4) - """Rotor speed dynamics time constant of the drone.""" - - @staticmethod - def create( - n_worlds: int, n_drones: int, drone_model: str, device: Device - ) -> FirstPrinciplesData: - """Create a default set of parameters for the simulation.""" - p = load_params("first_principles", drone_model) - J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device) - return FirstPrinciplesData( - mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device), - L=jnp.asarray(p["L"], device=device), - prop_inertia=jnp.asarray(p["prop_inertia"], device=device), - gravity_vec=jnp.asarray(p["gravity_vec"], device=device), - J=J, - J_inv=jnp.linalg.inv(J), - rpm2thrust=jnp.asarray(p["rpm2thrust"], device=device), - rpm2torque=jnp.asarray(p["rpm2torque"], device=device), - mixing_matrix=jnp.asarray(p["mixing_matrix"], device=device), - drag_matrix=jnp.asarray(p["drag_matrix"], device=device), - rotor_dyn_coef=jnp.asarray(p["rotor_dyn_coef"], device=device), - ) - - -def first_principles_physics(data: SimData) -> SimData: - """Compute the forces and torques from the first principle physics model.""" - params: FirstPrinciplesData = data.params - vel, _, acc, ang_acc, rotor_acc = first_principles_dynamics( - pos=data.states.pos, - quat=data.states.quat, - vel=data.states.vel, - ang_vel=data.states.ang_vel, - cmd=data.controls.rotor_vel, - rotor_vel=data.states.rotor_vel, - dist_f=data.states.force, - dist_t=data.states.torque, - mass=params.mass, - L=params.L, - prop_inertia=params.prop_inertia, - gravity_vec=params.gravity_vec, - J=params.J, - J_inv=params.J_inv, - rpm2thrust=params.rpm2thrust, - rpm2torque=params.rpm2torque, - mixing_matrix=params.mixing_matrix, - drag_matrix=params.drag_matrix, - rotor_dyn_coef=params.rotor_dyn_coef, - ) - states_deriv = data.states_deriv.replace( - vel=vel, ang_vel=data.states.ang_vel, acc=acc, ang_acc=ang_acc, rotor_acc=rotor_acc - ) - return data.replace(states_deriv=states_deriv) - - -@dataclass -class SoRpyData: - mass: Array # (N, M, 1) - """Mass of the drone.""" - gravity_vec: Array # (N, M, 3) - """Gravity vector of the drone.""" - J: Array # (N, M, 3, 3) - """Inertia matrix of the drone.""" - J_inv: Array # (N, M, 3, 3) - """Inverse of the inertia matrix of the drone.""" - acc_coef: Array # (N, M, 1) - """Coefficient for the acceleration.""" - cmd_f_coef: Array # (N, M, 1) - """Coefficient for the collective thrust.""" - rpy_coef: Array # (N, M, 1) - """Coefficient for the roll pitch yaw dynamics.""" - rpy_rates_coef: Array # (N, M, 1) - """Coefficient for the roll pitch yaw rates dynamics.""" - cmd_rpy_coef: Array # (N, M, 1) - """Coefficient for the roll pitch yaw command dynamics.""" - - @staticmethod - def create(n_worlds: int, n_drones: int, drone_model: str, device: Device) -> SoRpyData: - """Create a default set of parameters for the simulation.""" - p = load_params("so_rpy", drone_model) - J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device) - return SoRpyData( - mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device), - gravity_vec=jnp.asarray(p["gravity_vec"], device=device), - J=J, - J_inv=jnp.linalg.inv(J), - acc_coef=jnp.asarray(p["acc_coef"], device=device), - cmd_f_coef=jnp.asarray(p["cmd_f_coef"], device=device), - rpy_coef=jnp.asarray(p["rpy_coef"], device=device), - rpy_rates_coef=jnp.asarray(p["rpy_rates_coef"], device=device), - cmd_rpy_coef=jnp.asarray(p["cmd_rpy_coef"], device=device), - ) - - -def so_rpy_physics(data: SimData) -> SimData: - """Compute the forces and torques from the so_rpy physics model.""" - params: SoRpyData = data.params - vel, _, acc, ang_acc, _ = so_rpy_dynamics( - pos=data.states.pos, - quat=data.states.quat, - vel=data.states.vel, - ang_vel=data.states.ang_vel, - cmd=data.controls.attitude.cmd, - dist_f=data.states.force, - dist_t=data.states.torque, - mass=params.mass, - gravity_vec=params.gravity_vec, - J=params.J, - J_inv=params.J_inv, - acc_coef=params.acc_coef, - cmd_f_coef=params.cmd_f_coef, - rpy_coef=params.rpy_coef, - rpy_rates_coef=params.rpy_rates_coef, - cmd_rpy_coef=params.cmd_rpy_coef, - ) - states_deriv = data.states_deriv.replace( - vel=vel, ang_vel=data.states.ang_vel, acc=acc, ang_acc=ang_acc - ) - return data.replace(states_deriv=states_deriv) - - -@dataclass -class SoRpyRotorData: - mass: Array # (N, M, 1) - """Mass of the drone.""" - gravity_vec: Array # (N, M, 3) - """Gravity vector of the drone.""" - J: Array # (N, M, 3, 3) - """Inertia matrix of the drone.""" - J_inv: Array # (N, M, 3, 3) - """Inverse of the inertia matrix of the drone.""" - thrust_time_coef: Array # (N, M, 1) - """Rotor coefficient of the drone.""" - acc_coef: Array # (N, M, 1) - """Acceleration coefficient of the drone.""" - cmd_f_coef: Array # (N, M, 1) - """Collective thrust coefficient of the drone.""" - rpy_coef: Array # (N, M, 1) - """Roll pitch yaw coefficient of the drone.""" - rpy_rates_coef: Array # (N, M, 1) - """Roll pitch yaw rates coefficient of the drone.""" - cmd_rpy_coef: Array # (N, M, 1) - """Roll pitch yaw command coefficient of the drone.""" - - @staticmethod - def create(n_worlds: int, n_drones: int, drone_model: str, device: Device) -> SoRpyRotorData: - """Create a default set of parameters for the simulation.""" - p = load_params("so_rpy_rotor", drone_model) - J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device) - return SoRpyRotorData( - mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device), - gravity_vec=jnp.asarray(p["gravity_vec"], device=device), - J=J, - J_inv=jnp.linalg.inv(J), - thrust_time_coef=jnp.asarray(p["thrust_time_coef"], device=device), - acc_coef=jnp.asarray(p["acc_coef"], device=device), - cmd_f_coef=jnp.asarray(p["cmd_f_coef"], device=device), - rpy_coef=jnp.asarray(p["rpy_coef"], device=device), - rpy_rates_coef=jnp.asarray(p["rpy_rates_coef"], device=device), - cmd_rpy_coef=jnp.asarray(p["cmd_rpy_coef"], device=device), - ) - - -def so_rpy_rotor_physics(data: SimData) -> SimData: - """Compute the forces and torques from the so_rpy_rotor physics model.""" - params: SoRpyRotorData = data.params - vel, _, acc, ang_acc, rotor_acc = so_rpy_rotor_dynamics( - pos=data.states.pos, - quat=data.states.quat, - vel=data.states.vel, - ang_vel=data.states.ang_vel, - rotor_vel=data.states.rotor_vel, - cmd=data.controls.attitude.cmd, - dist_f=data.states.force, - dist_t=data.states.torque, - mass=params.mass, - gravity_vec=params.gravity_vec, - J=params.J, - J_inv=params.J_inv, - thrust_time_coef=params.thrust_time_coef, - acc_coef=params.acc_coef, - cmd_f_coef=params.cmd_f_coef, - rpy_coef=params.rpy_coef, - rpy_rates_coef=params.rpy_rates_coef, - cmd_rpy_coef=params.cmd_rpy_coef, - ) - states_deriv = data.states_deriv.replace( - vel=vel, ang_vel=data.states.ang_vel, acc=acc, ang_acc=ang_acc, rotor_acc=rotor_acc - ) - return data.replace(states_deriv=states_deriv) - - -@dataclass -class SoRpyRotorDragData: - mass: Array # (N, M, 1) - """Mass of the drone.""" - gravity_vec: Array # (N, M, 3) - """Gravity vector of the drone.""" - J: Array # (N, M, 3, 3) - """Inertia matrix of the drone.""" - J_inv: Array # (N, M, 3, 3) - """Inverse of the inertia matrix of the drone.""" - thrust_time_coef: Array # (N, M, 1) - """Rotor coefficient of the drone.""" - acc_coef: Array # (N, M, 1) - """Acceleration coefficient of the drone.""" - cmd_f_coef: Array # (N, M, 1) - """Collective thrust coefficient of the drone.""" - rpy_coef: Array # (N, M, 1) - """Roll pitch yaw coefficient of the drone.""" - rpy_rates_coef: Array # (N, M, 1) - """Roll pitch yaw rates coefficient of the drone.""" - cmd_rpy_coef: Array # (N, M, 1) - """Roll pitch yaw command coefficient of the drone.""" - drag_matrix: Array # (N, M, 3, 3) - """Linear drag coefficient matrix of the drone.""" - - @staticmethod - def create( - n_worlds: int, n_drones: int, drone_model: str, device: Device - ) -> SoRpyRotorDragData: - """Create a default set of parameters for the simulation.""" - p = load_params("so_rpy_rotor_drag", drone_model) - J = jax.device_put(jnp.tile(p["J"][None, None, :, :], (n_worlds, n_drones, 1, 1)), device) - return SoRpyRotorDragData( - mass=jnp.full((n_worlds, n_drones, 1), p["mass"], device=device), - gravity_vec=jnp.asarray(p["gravity_vec"], device=device), - J=J, - J_inv=jnp.linalg.inv(J), - thrust_time_coef=jnp.asarray(p["thrust_time_coef"], device=device), - acc_coef=jnp.asarray(p["acc_coef"], device=device), - cmd_f_coef=jnp.asarray(p["cmd_f_coef"], device=device), - rpy_coef=jnp.asarray(p["rpy_coef"], device=device), - rpy_rates_coef=jnp.asarray(p["rpy_rates_coef"], device=device), - cmd_rpy_coef=jnp.asarray(p["cmd_rpy_coef"], device=device), - drag_matrix=jnp.asarray(p["drag_matrix"], device=device), - ) - - -def so_rpy_rotor_drag_physics(data: SimData) -> SimData: - """Compute the forces and torques from the so_rpy_rotor_drag physics model.""" - params: SoRpyRotorDragData = data.params - vel, _, acc, ang_acc, rotor_acc = so_rpy_rotor_drag_dynamics( - pos=data.states.pos, - quat=data.states.quat, - vel=data.states.vel, - ang_vel=data.states.ang_vel, - cmd=data.controls.attitude.cmd, - rotor_vel=data.states.rotor_vel, - dist_f=data.states.force, - dist_t=data.states.torque, - mass=params.mass, - gravity_vec=params.gravity_vec, - J=params.J, - J_inv=params.J_inv, - thrust_time_coef=params.thrust_time_coef, - acc_coef=params.acc_coef, - cmd_f_coef=params.cmd_f_coef, - rpy_coef=params.rpy_coef, - rpy_rates_coef=params.rpy_rates_coef, - cmd_rpy_coef=params.cmd_rpy_coef, - drag_matrix=params.drag_matrix, - ) - states_deriv = data.states_deriv.replace( - vel=vel, ang_vel=data.states.ang_vel, acc=acc, ang_acc=ang_acc, rotor_acc=rotor_acc - ) - return data.replace(states_deriv=states_deriv) diff --git a/crazyflow/sim/sensors.py b/crazyflow/sim/sensors.py index 823c027..806d046 100644 --- a/crazyflow/sim/sensors.py +++ b/crazyflow/sim/sensors.py @@ -36,7 +36,7 @@ def build_render_depth_fn( ) -> Callable[[Sim], Array]: """Build a depth renderer function for given camera and resolution. - Compiles the mjx model and rays directly into the rendering function for higher performance. The + Compiles the mjx_model and rays directly into the rendering function for higher performance. The returned function takes a Sim object as input and returns depth images. """ rays = _camera_rays(resolution=resolution, fov_y=jnp.pi / 4)[None, ...] diff --git a/crazyflow/sim/sim.py b/crazyflow/sim/sim.py index 2ac365e..6f253b9 100644 --- a/crazyflow/sim/sim.py +++ b/crazyflow/sim/sim.py @@ -5,44 +5,36 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar -import drone_models import jax import jax.numpy as jnp import mujoco import mujoco.mjx as mjx -from drone_controllers.mellinger import ( - attitude2force_torque, - force_torque2rotor_vel, - state2attitude, -) from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer from jax import Array, Device import crazyflow.sim.functional as F -from crazyflow.control.control import Control, controllable +from crazyflow.control import Control +from crazyflow.control.mellinger import ( + control_attitude2force_torque, + control_commit_attitude, + control_force_torque2rotor_vel, + control_state2attitude, +) +from crazyflow.dynamics import Dynamics +from crazyflow.dynamics.first_principles import sim_dynamics as first_principles_dynamics +from crazyflow.dynamics.so_rpy import sim_dynamics as so_rpy_dynamics +from crazyflow.dynamics.so_rpy_rotor import sim_dynamics as so_rpy_rotor_dynamics +from crazyflow.dynamics.so_rpy_rotor_drag import sim_dynamics as so_rpy_rotor_drag_dynamics from crazyflow.exception import ConfigError, NotInitializedError from crazyflow.sim.data import SimControls, SimCore, SimData, SimParams, SimState, SimStateDeriv from crazyflow.sim.integration import Integrator, euler, rk4, symplectic_euler -from crazyflow.sim.physics import ( - Physics, - first_principles_physics, - so_rpy_physics, - so_rpy_rotor_drag_physics, - so_rpy_rotor_physics, -) from crazyflow.sim.pipeline import append_fn -from crazyflow.utils import grid_2d, leaf_replace, pytree_replace +from crazyflow.utils import grid_2d, pytree_replace if TYPE_CHECKING: from mujoco.mjx import Data, Model from numpy.typing import NDArray - from crazyflow.control.mellinger import ( - MellingerAttitudeData, - MellingerForceTorqueData, - MellingerStateData, - ) - Params = ParamSpec("Params") # Represents arbitrary parameters Return = TypeVar("Return") # Represents the return type @@ -60,12 +52,23 @@ def wrapper(sim: Sim, *args: Any, **kwargs: Any) -> SimData: class Sim: + """Crazyflow simulation. + + Used both through its object-oriented methods (:meth:`step`, :meth:`reset`, the ``*_control`` + setters) and as the builder for the functional API in :mod:`crazyflow.sim.functional`, which + operates on the ``sim.data`` and pipelines constructed here. + + The simulation is always batched. Every quantity in ``sim.data`` has a leading + ``(n_worlds, n_drones, ...)`` shape, even for a single world and drone. ``n_worlds`` indexes + parallel copies of the scene and ``n_drones`` the drones within each world. + """ + def __init__( self, n_worlds: int = 1, n_drones: int = 1, - drone_model: str = "cf2x_L250", - physics: Physics = Physics.default, + drone: str = "cf2x_L250", + dynamics: Dynamics = Dynamics.default, control: Control = Control.default, integrator: Integrator = Integrator.default, freq: int = 500, @@ -77,16 +80,36 @@ def __init__( rng_key: int = 0, fused_mjx_model: bool = False, ): - assert Physics(physics) in Physics, f"Physics mode {physics} not implemented" + """Build the scene and the step and reset pipelines, and allocate the batched sim data. + + Args: + n_worlds: Number of parallel worlds to simulate. + n_drones: Number of drones per world. + drone: Name of the drone. + dynamics: Dynamics used to advance the drone state. + control: Control interface exposed to the user. + integrator: Integration scheme for the dynamics. + freq: Dynamics step frequency in Hz. + state_freq: Frequency in Hz at which the state controller runs. + attitude_freq: Frequency in Hz at which the attitude controller runs. + force_torque_freq: Frequency in Hz at which the force/torque controller runs. + device: Device to place the simulation data on (e.g. ``"cpu"`` or ``"gpu"``). + xml_path: Path to a custom scene XML. Defaults to ``crazyflow/scene.xml``. + rng_key: Seed for the JAX rng key. + fused_mjx_model: If True, use the ``drone_fused`` body whose visual geometry is fused + into a single mesh. This shrinks the MJX model and reduces its memory footprint at + the cost of visual detail. + """ + assert Dynamics(dynamics) in Dynamics, f"Dynamics mode {dynamics} not implemented" assert Control(control) in Control, f"Control mode {control} not implemented" - if physics != Physics.first_principles: + if dynamics != Dynamics.first_principles: if control in (Control.force_torque, Control.rotor_vel): - raise ConfigError(f"Control mode {control} requires first principles physics") + raise ConfigError(f"Control mode {control} requires first principles dynamics") if freq > 10_000 and not jax.config.jax_enable_x64: raise ConfigError("High frequency simulations require double precision mode") - self.physics = physics + self.dynamics = dynamics self.control = control - self.drone_model = drone_model + self.drone = drone self.integrator = integrator self.device = jax.devices(device)[0] self.n_worlds = n_worlds @@ -95,9 +118,9 @@ def __init__( self.max_visual_geom = 1000 # Initialize MuJoCo world and data + self.fused_mjx_model = fused_mjx_model self._xml_path = xml_path or Path(__file__).parents[1] / "scene.xml" - model_file_name = f"{drone_model}{'_fused' if fused_mjx_model else ''}.xml" - self.drone_path = Path(drone_models.__file__).parent / "data" / model_file_name + self.drone_path = Path(__file__).parents[1] / f"drones/{drone}.xml" self.spec = self.build_mjx_spec() self.mj_model, self.mj_data, self.mjx_model, self.mjx_data = self.build_mjx_model(self.spec) self.viewer: MujocoRenderer | None = None @@ -115,9 +138,9 @@ def __init__( # The ``select_xxx_fn`` methods return functions, not the results of calling those # functions. They act as factories that produce building blocks for the construction of our # simulation pipeline. - for fn in build_control_fns(self.control, self.physics): - append_fn(self.step_pipeline, fn) - integrate_fn = select_integrate_fn(self.integrator, select_physics_fn(self.physics)) + for name, fn in build_control_fns(self.control, self.dynamics): + append_fn(self.step_pipeline, fn, name=name) + integrate_fn = select_integrate_fn(self.integrator, select_dynamics_fn(self.dynamics)) append_fn(self.step_pipeline, integrate_fn, name="integration") append_fn(self.step_pipeline, increment_steps) # We never drop below -0.001 (drones can't pass through the floor). We use -0.001 to @@ -215,17 +238,18 @@ def close(self): self.viewer = None def build_mjx_spec(self) -> mujoco.MjSpec: - """Build the MuJoCo model specification for the simulation.""" + """Build the MuJoCo mjx_model specification for the simulation.""" assert self._xml_path.exists(), f"Model file {self._xml_path} does not exist" spec = mujoco.MjSpec.from_file(str(self._xml_path)) spec.option.timestep = 1 / self.freq spec.copy_during_attach = True drone_spec = mujoco.MjSpec.from_file(str(self.drone_path)) frame = spec.worldbody.add_frame(name="world") - if (drone_body := drone_spec.body("drone")) is None: + name = "drone_fused" if self.fused_mjx_model else "drone" + if (drone_body := drone_spec.body(name)) is None: raise ValueError("Drone body not found in drone spec") # Mocap bodies avoid the nv^2 cost of qM/qLD/efc_J. A single dummy slide joint keeps nv=1 so - # mjx.kinematics doesn't error on a zero-DOF model. + # mjx.kinematics doesn't error on a zero-DOF mjx_model. dummy = spec.worldbody.add_body() dummy.name = "_dummy" dummy.mass = 1e-6 @@ -362,8 +386,9 @@ def init_data( self, state_freq: int, attitude_freq: int, force_torque_freq: int, rng_key: Array ) -> SimData: """Initialize the simulation data.""" + drone_name = "drone_fused" if self.fused_mjx_model else "drone" drone_mocap_ids = [ - self.mj_model.body(f"drone:{i}").mocapid.item() for i in range(self.n_drones) + self.mj_model.body(f"{drone_name}:{i}").mocapid.item() for i in range(self.n_drones) ] N, D = self.n_worlds, self.n_drones data = SimData( @@ -373,13 +398,13 @@ def init_data( N, D, self.control, - self.drone_model, + self.drone, state_freq, attitude_freq, force_torque_freq, self.device, ), - params=SimParams.create(N, D, self.physics, self.drone_model, self.device), + params=SimParams.create(N, D, self.dynamics, self.drone, self.device), core=SimCore.create(self.freq, N, D, drone_mocap_ids, rng_key, self.device), ) if D > 1: # If multiple drones, arrange them in a grid @@ -440,55 +465,59 @@ def _step(data: SimData, n_steps: int) -> SimData: def build_control_fns( - control: Control, physics: Physics -) -> tuple[Callable[[SimData], SimData], ...]: - """Select the control functions for the given control mode. + control: Control, dynamics: Dynamics +) -> tuple[tuple[str, Callable[[SimData], SimData]], ...]: + """Select the named control stages for the given control mode. Note: - This function returns a tuple of functions, not a single function. The returned functions - are called in succession in the simulation pipeline. + Returns ``(name, fn)`` pairs, called in succession in the simulation pipeline. The names are + the stable pipeline stage identifiers used to insert, replace, or remove stages. """ + state = ("state_controller", control_state2attitude) + attitude = ("attitude_controller", control_attitude2force_torque) + force_torque = ("force_torque_controller", control_force_torque2rotor_vel) + commit_attitude = ("commit_attitude", control_commit_attitude) match control: case Control.state: - control_pipeline = (step_state_controller, step_attitude_controller) - if physics == Physics.first_principles: - control_pipeline = control_pipeline + (step_force_torque_controller,) + stages = (state, attitude) + if dynamics == Dynamics.first_principles: + stages = stages + (force_torque,) case Control.attitude: - if physics == Physics.first_principles: - control_pipeline = (step_attitude_controller, step_force_torque_controller) - elif physics in (Physics.so_rpy, Physics.so_rpy_rotor, Physics.so_rpy_rotor_drag): - control_pipeline = (commit_attitude_controller,) + if dynamics == Dynamics.first_principles: + stages = (attitude, force_torque) + elif dynamics in (Dynamics.so_rpy, Dynamics.so_rpy_rotor, Dynamics.so_rpy_rotor_drag): + stages = (commit_attitude,) else: - raise NotImplementedError(f"Control mode {control} not implemented for {physics}") + raise NotImplementedError(f"Control mode {control} not implemented for {dynamics}") case Control.force_torque: - control_pipeline = (step_force_torque_controller,) + stages = (force_torque,) case Control.rotor_vel: - control_pipeline = () + stages = () case _: raise NotImplementedError(f"Control mode {control} not implemented") - return control_pipeline + return stages -def select_physics_fn(physics: Physics) -> Callable[[SimData], SimData]: - """Select the physics function for the given physics mode.""" - match physics: - case Physics.first_principles: - return first_principles_physics - case Physics.so_rpy: - return so_rpy_physics - case Physics.so_rpy_rotor: - return so_rpy_rotor_physics - case Physics.so_rpy_rotor_drag: - return so_rpy_rotor_drag_physics +def select_dynamics_fn(dynamics: Dynamics) -> Callable[[SimData], SimData]: + """Select the dynamics function for the given dynamics mode.""" + match dynamics: + case Dynamics.first_principles: + return first_principles_dynamics + case Dynamics.so_rpy: + return so_rpy_dynamics + case Dynamics.so_rpy_rotor: + return so_rpy_rotor_dynamics + case Dynamics.so_rpy_rotor_drag: + return so_rpy_rotor_drag_dynamics case _: - raise NotImplementedError(f"Physics mode {physics} not implemented") + raise NotImplementedError(f"Dynamics mode {dynamics} not implemented") def select_integrate_fn( - integrator: Integrator, physics_fn: Callable[[SimData], SimData] + integrator: Integrator, dynamics_fn: Callable[[SimData], SimData] ) -> Callable[[SimData], SimData]: - """Select the integration function for the given physics and integrator mode.""" + """Select the integration function for the given dynamics and integrator mode.""" match integrator: case Integrator.euler: integrate_fn = euler @@ -499,7 +528,7 @@ def select_integrate_fn( case _: raise NotImplementedError(f"Integrator {integrator} not implemented") - return partial(integrate_fn, deriv_fn=physics_fn) + return partial(integrate_fn, deriv_fn=dynamics_fn) def reset(data: SimData, default_data: SimData, mask: Array | None = None) -> SimData: @@ -524,7 +553,7 @@ def contacts(geom_start: int, geom_count: int, data: Data) -> Array: @jax.jit def sync_sim2mjx(data: SimData, mjx_data: Data, mjx_model: Model) -> tuple[SimData, Data]: - """Synchronize the simulation data with the MuJoCo model.""" + """Synchronize the simulation data with the MuJoCo mjx_model.""" pos, quat = data.states.pos, data.states.quat quat_mjx = jnp.roll(quat, 1, axis=-1) # MuJoCo quat is [w, x, y, z], ours is [x, y, z, w] ids = data.core.drone_mocap_ids @@ -539,79 +568,6 @@ def sync_sim2mjx(data: SimData, mjx_data: Data, mjx_model: Model) -> tuple[SimDa return data, mjx_data -def step_state_controller(data: SimData) -> SimData: - """Compute the updated controls for the state controller.""" - states = data.states - state_ctrl: MellingerStateData = data.controls.state - assert state_ctrl is not None, "Using state controller without initialized data" - mask = controllable(data.core.steps, data.core.freq, state_ctrl.steps, state_ctrl.freq) - state_ctrl = leaf_replace(state_ctrl, mask, cmd=state_ctrl.staged_cmd) - rpyt, pos_err_i = state2attitude( - states.pos, - states.quat, - states.vel, - state_ctrl.cmd, - ctrl_errors=(state_ctrl.pos_err_i,), - ctrl_freq=state_ctrl.freq, - **state_ctrl.params, - ) - state_ctrl = leaf_replace(state_ctrl, mask, steps=data.core.steps, pos_err_i=pos_err_i) - attitude_ctrl = leaf_replace(data.controls.attitude, mask, staged_cmd=rpyt) - return data.replace(controls=data.controls.replace(state=state_ctrl, attitude=attitude_ctrl)) - - -def step_attitude_controller(data: SimData) -> SimData: - """Compute the updated controls for the attitude controller.""" - states = data.states - attitude_ctrl: MellingerAttitudeData = data.controls.attitude - assert attitude_ctrl is not None, "Using attitude controller without initialized data" - mask = controllable(data.core.steps, data.core.freq, attitude_ctrl.steps, attitude_ctrl.freq) - attitude_ctrl = leaf_replace(attitude_ctrl, mask, cmd=attitude_ctrl.staged_cmd) - force, torque, r_int_error = attitude2force_torque( - states.quat, - states.ang_vel, - attitude_ctrl.cmd, - ctrl_errors=(attitude_ctrl.r_int_error,), - ctrl_freq=attitude_ctrl.freq, - prev_ang_vel=attitude_ctrl.last_ang_vel, - **attitude_ctrl.params, - ) - attitude_ctrl = leaf_replace( - attitude_ctrl, - mask, - r_int_error=r_int_error, - last_ang_vel=states.ang_vel, - steps=data.core.steps, - ) - ft_ctrl = leaf_replace( - data.controls.force_torque, mask, staged_cmd=jnp.concat([force, torque], axis=-1) - ) - return data.replace( - states=states, controls=data.controls.replace(attitude=attitude_ctrl, force_torque=ft_ctrl) - ) - - -def commit_attitude_controller(data: SimData) -> SimData: - """Commit the staged attitude command to the controller setpoint.""" - attitude_ctrl: MellingerAttitudeData = data.controls.attitude - mask = controllable(data.core.steps, data.core.freq, attitude_ctrl.steps, attitude_ctrl.freq) - attitude_ctrl = leaf_replace(attitude_ctrl, mask, cmd=attitude_ctrl.staged_cmd) - return data.replace(controls=data.controls.replace(attitude=attitude_ctrl)) - - -def step_force_torque_controller(data: SimData) -> SimData: - """Compute the updated controls for the thrust controller.""" - ft_ctrl: MellingerForceTorqueData = data.controls.force_torque - assert ft_ctrl is not None, "Using force torque controller without initialized data" - mask = controllable(data.core.steps, data.core.freq, ft_ctrl.steps, ft_ctrl.freq) - ft_ctrl = leaf_replace(ft_ctrl, mask, cmd=ft_ctrl.staged_cmd) - rotor_vel = force_torque2rotor_vel( - ft_ctrl.cmd[..., [0]], ft_ctrl.cmd[..., 1:], **ft_ctrl.params - ) - ft_ctrl = leaf_replace(ft_ctrl, mask, steps=data.core.steps) - return data.replace(controls=data.controls.replace(rotor_vel=rotor_vel, force_torque=ft_ctrl)) - - def clip_floor_pos(data: SimData) -> SimData: """Clip the position of the drone to the floor.""" clip = data.states.pos[..., 2] < -0.001 diff --git a/crazyflow/sim/symbolic.py b/crazyflow/sim/symbolic.py deleted file mode 100644 index 4557072..0000000 --- a/crazyflow/sim/symbolic.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from drone_models import parametrize -from drone_models.first_principles import symbolic_dynamics as first_principles_symbolic_dynamics -from drone_models.so_rpy import symbolic_dynamics as so_rpy_symbolic_dynamics - -from crazyflow.sim.data import Control -from crazyflow.sim.physics import Physics - -if TYPE_CHECKING: - import casadi as cs - - from crazyflow.sim import Sim - - -def symbolic_from_sim( - sim: Sim, model_rotor_vel: bool = False, model_dist_f: bool = False, model_dist_t: bool = False -) -> tuple[cs.MX, cs.MX, cs.MX, cs.MX]: - """Create a symbolic model from a simulation object. - - Args: - sim: The simulation object. - model_rotor_vel: Flag to model the rotor velocity. - model_dist_f: Flag to model the distributed force. - model_dist_t: Flag to model the distributed torque. - - Returns: - The four symbolic expressions for X_dot, X, U, Y. - """ - if sim.control != Control.attitude: - raise ValueError("Symbolic model dynamics only support attitude control") - match sim.physics: - case Physics.first_principles: - return parametrize(first_principles_symbolic_dynamics, sim.drone_model)( - model_rotor_vel=model_rotor_vel, - model_dist_f=model_dist_f, - model_dist_t=model_dist_t, - ) - case Physics.so_rpy: - return parametrize(so_rpy_symbolic_dynamics, sim.drone_model)( - model_rotor_vel=model_rotor_vel, - model_dist_f=model_dist_f, - model_dist_t=model_dist_t, - ) - case _: - raise ValueError(f"Physics mode {sim.physics} not supported") diff --git a/crazyflow/utils.py b/crazyflow/utils.py index 88c56bd..30b5d1b 100644 --- a/crazyflow/utils.py +++ b/crazyflow/utils.py @@ -1,13 +1,19 @@ from __future__ import annotations +import inspect +from collections.abc import Mapping from functools import partial from pathlib import Path -from typing import TypeVar +from typing import TYPE_CHECKING, Any, Callable, ParamSpec, TypeVar import jax import jax.numpy as jnp +import numpy as np from jax import Array +if TYPE_CHECKING: + from types import ModuleType + def grid_2d(n: int, spacing: float = 1.0, center: Array | None = None) -> Array: """Generate a 2D grid of points.""" @@ -21,6 +27,8 @@ def grid_2d(n: int, spacing: float = 1.0, center: Array | None = None) -> Array: T = TypeVar("T") # PyTree type +P = ParamSpec("P") +R = TypeVar("R") def pytree_replace(tree: T, new_tree: T, mask: Array | None = None) -> T: @@ -80,3 +88,56 @@ def enable_cache( jax.config.update("jax_persistent_cache_min_compile_time_secs", min_compile_time_secs) if enable_xla_caches: jax.config.update("jax_persistent_cache_enable_xla_caches", "all") + + +def parametrize( + fn: Callable[P, R], + drone: str, + load_params: Callable[..., dict], + xp: ModuleType | None = None, + device: str | None = None, +) -> Callable[P, R]: + """Parametrize a function with the default parameters for a drone. + + Args: + fn: The function to parametrize. + drone: The drone to use. + load_params: The function to load the parameters for the given drone. + xp: The array API module to use. If not provided, numpy is used. + device: The device to use. If none, the device is inferred from the xp module. + + Returns: + The parametrized function with all keyword only arguments filled in. + """ + params = load_params(fn, drone, xp=xp, device=device) + xp = np if xp is None else xp + fn_params = inspect.signature(fn).parameters + fn_kwargs = {k for k, v in fn_params.items() if v.kind == inspect.Parameter.KEYWORD_ONLY} + kwargs = {k: xp.asarray(v, device=device) for k, v in params.items() if k in fn_kwargs} + return partial(fn, **kwargs) + + +def filter_to_signature(params: dict, fn: Callable) -> dict: + """Keep only the params accepted by ``fn``. + + Asserts that every keyword-only parameter of ``fn`` (the injectable params, as opposed to the + positional runtime inputs) is present in ``params``. + """ + sig = inspect.signature(fn).parameters + filtered = {k: v for k, v in params.items() if k in sig} + required = {k for k, p in sig.items() if p.kind == inspect.Parameter.KEYWORD_ONLY} + missing = required - filtered.keys() + assert not missing, f"Missing parameters for {fn.__name__}: {missing}" + return filtered + + +def to_xp(*args: Any, xp: ModuleType | None = None, device: Any = None) -> Any: + """Convert arrays, dicts etc recursively to the ``xp`` namespace and device.""" + xp = np if xp is None else xp + match args: + case [Mapping() as m]: + return {k: to_xp(v, xp=xp, device=device) for k, v in m.items()} + case [single]: + return xp.asarray(single, device=device) + case _: + return tuple(to_xp(a, xp=xp, device=device) for a in args) diff --git a/docs/api/index.md b/docs/api/index.md index 969e27c..4a502eb 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -6,14 +6,13 @@ This section is auto-generated from the Crazyflow source code using [mkdocstring | Module | Description | |---|---| -| `crazyflow.sim` | Core `Sim` class and physics pipeline | +| `crazyflow.sim` | Core `Sim` class and simulation pipeline | | `crazyflow.sim.pipeline` | `OrderedDict`-based pipeline helpers (`append_fn`, `insert_fn_before`, `replace_fn`, etc.) | | `crazyflow.sim.data` | `SimData`, `SimState`, `SimControls`, `SimParams`, `SimCore` pytrees | | `crazyflow.sim.functional` | Pure functional control API for use inside `jax.jit` | -| `crazyflow.sim.physics` | `Physics` enum and physics model implementations | +| `crazyflow.dynamics` | `Dynamics` enum and dynamics implementations | | `crazyflow.sim.integration` | `Integrator` enum, Euler, RK4, and symplectic Euler | | `crazyflow.sim.sensors` | Raycasting and sensor extraction utilities | -| `crazyflow.sim.symbolic` | CasADi symbolic model API | | `crazyflow.control` | `Control` enum | | `crazyflow.control.mellinger` | Mellinger controller data and parameters | | `crazyflow.envs` | Gymnasium vectorized environments | diff --git a/docs/examples/index.md b/docs/examples/index.md index 7cd4774..f4baaf7 100644 --- a/docs/examples/index.md +++ b/docs/examples/index.md @@ -30,7 +30,7 @@ Commanding roll, pitch, yaw, and collective thrust directly. This level bypasses ## Sampling-based MPC -A sampling-based model predictive controller tracks a Lissajous curve while avoiding a grid of obstacles. It rolls out thousands of candidate control sequences in parallel using a reduced dynamics model, then applies the first action from a cost-weighted update of the best samples. The controller automatically uses a GPU when one is available and lowers the sample count on CPU. +A sampling-based model predictive controller tracks a Lissajous curve while avoiding a grid of obstacles. It rolls out thousands of candidate control sequences in parallel using identified dynamics, then applies the first action from a cost-weighted update of the best samples. The controller automatically uses a GPU when one is available and lowers the sample count on CPU. ```bash python examples/control/sampling.py @@ -64,7 +64,7 @@ python examples/plugins/randomize.py ## Disturbance injection -Inserting a random external force and torque into the step pipeline. The disturbance fires on every physics tick, so the drone fights wind-like perturbations. +Inserting a random external force and torque into the step pipeline. The disturbance fires on every dynamics tick, so the drone fights wind-like perturbations. ```{ .python notest } --8<-- "examples/plugins/disturbance.py" diff --git a/docs/gen_ref_pages.py b/docs/gen_ref_pages.py index 73a4324..c42ec0a 100644 --- a/docs/gen_ref_pages.py +++ b/docs/gen_ref_pages.py @@ -34,6 +34,10 @@ with mkdocs_gen_files.open(full_doc_path, "w") as fd: ident = ".".join(parts) fd.write(f"::: {ident}\n") + # Dynamics is re-exported by crazyflow.sim for convenience but documented under + # crazyflow.dynamics. Filter it out here so it is not rendered on both pages. + if ident == "crazyflow.sim": + fd.write(' options:\n filters: ["!^Dynamics$"]\n') mkdocs_gen_files.set_edit_path(full_doc_path, path) @@ -44,14 +48,23 @@ * [sim](crazyflow/sim/index.md) * [sim.data](crazyflow/sim/data.md) * [sim.functional](crazyflow/sim/functional.md) - * [sim.physics](crazyflow/sim/physics.md) * [sim.integration](crazyflow/sim/integration.md) * [sim.sensors](crazyflow/sim/sensors.md) - * [sim.symbolic](crazyflow/sim/symbolic.md) * [sim.visualize](crazyflow/sim/visualize.md) +* Dynamics + * [dynamics](crazyflow/dynamics/index.md) + * [dynamics.core](crazyflow/dynamics/core.md) + * [dynamics.first_principles](crazyflow/dynamics/first_principles/index.md) + * [dynamics.so_rpy](crazyflow/dynamics/so_rpy/index.md) + * [dynamics.so_rpy_rotor](crazyflow/dynamics/so_rpy_rotor/index.md) + * [dynamics.so_rpy_rotor_drag](crazyflow/dynamics/so_rpy_rotor_drag/index.md) + * [dynamics.symbols](crazyflow/dynamics/symbols.md) * Control * [control](crazyflow/control/index.md) - * [control.mellinger](crazyflow/control/mellinger.md) + * [control.core](crazyflow/control/core.md) + * [control.transform](crazyflow/control/transform.md) + * [control.mellinger](crazyflow/control/mellinger/index.md) + * [control.mellinger.control](crazyflow/control/mellinger/control.md) * Environments * [envs](crazyflow/envs/index.md) * [envs.drone_env](crazyflow/envs/drone_env.md) diff --git a/docs/get-started/installation.md b/docs/get-started/installation.md index ffac14e..f51e5fe 100644 --- a/docs/get-started/installation.md +++ b/docs/get-started/installation.md @@ -17,7 +17,7 @@ Select your installation method from the tabs below, then read the notes under e === "pixi" ```bash - git clone --recurse-submodules git@github.com:learnsyslab/crazyflow.git + git clone https://github.com/learnsyslab/crazyflow.git cd crazyflow pixi shell ``` @@ -25,7 +25,7 @@ Select your installation method from the tabs below, then read the notes under e === "pixi + tests" ```bash - git clone --recurse-submodules git@github.com:learnsyslab/crazyflow.git + git clone https://github.com/learnsyslab/crazyflow.git cd crazyflow pixi shell -e tests ``` @@ -41,7 +41,7 @@ JAX defaults to CPU-only execution. The `gpu` extra swaps in `jax[cuda12]`, enab ## Developer install -[Pixi](https://pixi.sh/) creates a fully reproducible environment. This variant installs `crazyflow`, `drone_models`, and `drone_controllers` in editable mode from the `submodules/` folder. Any source change takes effect immediately without reinstalling. Recommended for contributors and researchers who modify the simulator. +[Pixi](https://pixi.sh/) creates a fully reproducible environment. This variant installs `crazyflow` in editable mode. Any source change takes effect immediately without reinstalling. Recommended for contributors and researchers who modify the simulator. ## Testing diff --git a/docs/get-started/quick-start.md b/docs/get-started/quick-start.md index 26e9de5..e578790 100644 --- a/docs/get-started/quick-start.md +++ b/docs/get-started/quick-start.md @@ -4,7 +4,7 @@ This page walks through a complete minimal workflow: create a simulator, send a ## Create a simulator -`Sim` is the top-level object. All configuration is provided at construction time: physics model, control mode, simulation frequency, number of parallel worlds, and number of drones per world. +`Sim` is the top-level object. All configuration is provided at construction time: dynamics, control mode, simulation frequency, number of parallel worlds, and number of drones per world. ```python from crazyflow.sim import Sim @@ -45,7 +45,7 @@ cmd[0, 0, 2] = 0.5 # target height: 0.5 m ## Step the simulation -`state_control` stages the command. `step` advances the simulation by the given number of physics steps. Calling `sim.step(sim.freq // sim.control_freq)` advances exactly one control cycle. +`state_control` stages the command. `step` advances the simulation by the given number of dynamics steps. Calling `sim.step(sim.freq // sim.control_freq)` advances exactly one control cycle. ```python import numpy as np @@ -136,5 +136,5 @@ pos = sim.data.states.pos[0, :, :] # (4, 3) — all 4 drones in world 0 - [Object-Oriented API](../user-guide/oo-api.md) — all control modes, rendering, and reset - [Functional API](../user-guide/functional-api.md) — purely functional interface for use inside JAX transformations -- [Physics Models](../user-guide/physics-models.md) — choosing between first-principles and fitted models +- [Dynamics](../user-guide/dynamics/index.md) — choosing between first-principles and fitted dynamics - [Examples](../examples/index.md) — runnable scripts diff --git a/docs/index.md b/docs/index.md index fe87258..ac9f761 100644 --- a/docs/index.md +++ b/docs/index.md @@ -6,7 +6,7 @@ **Fast, parallelizable simulations of Crazyflie drones with JAX.** -Crazyflow is a research simulator for Crazyflie-style quadrotors that runs millions of independent environments in parallel on CPU or GPU. It is built on JAX, exposes a differentiable dynamics pipeline, and ships identified models for the Crazyflie 2.x family. +Crazyflow is a research simulator for Crazyflie-style quadrotors that runs millions of independent environments in parallel on CPU or GPU. It is built on JAX, exposes a differentiable dynamics pipeline, and ships identified dynamics for the Crazyflie 2.x family. --- @@ -63,7 +63,7 @@ Crazyflow is a research simulator for Crazyflie-style quadrotors that runs milli allowfullscreen loading="lazy" > - + @@ -86,11 +86,11 @@ Crazyflow is a research simulator for Crazyflie-style quadrotors that runs milli ## Supported drones -All models come from the [drone-models](https://learnsyslab.github.io/drone-models/) library. Available configurations: `cf2x_L250`, `cf2x_P250`, `cf2x_T350`, `cf21B_500`, and any model returned by `drone_models.available_drones()`. +All drone configurations are bundled with `crazyflow.dynamics`. Available configurations: `cf2x_L250`, `cf2x_P250`, `cf2x_T350`, `cf21B_500`, and any drone returned by `crazyflow.available_drones`. --- @@ -109,7 +109,7 @@ All models come from the [drone-models](https://learnsyslab.github.io/drone-mode crazyflow_experiments commit 6b65eeedefe32690f1e5ca7818d62439314f0de5 --> -Throughput for one drone across parallel worlds, first-principles physics. CPU: AMD Ryzen 9 7950X. GPU: NVIDIA RTX 4090. +Throughput for one drone across parallel worlds, first-principles dynamics. CPU: AMD Ryzen 9 7950X. GPU: NVIDIA RTX 4090. ```vegalite { @@ -203,9 +203,9 @@ GPU throughput across `n_worlds` and `n_drones` (RTX 4090). Empty cells exceed a ## Why Crazyflow -Most simulators offer either vectorized environments for RL training or multi-drone swarm simulation — rarely both, and rarely with accurate onboard flight dynamics for every agent. Crazyflow is built around both simultaneously. The entire simulator is structured around an `n_worlds × n_drones` batch dimension: `n_worlds` gives you massively parallel independent environments, and `n_drones` gives you full swarm simulation inside each one, each drone running its own accurate, identified flight model and control stack. Scaling to millions of parallel instances requires no code changes. +Most simulators offer either vectorized environments for RL training or multi-drone swarm simulation — rarely both, and rarely with accurate onboard flight dynamics for every agent. Crazyflow is built around both simultaneously. The entire simulator is structured around an `n_worlds × n_drones` batch dimension: `n_worlds` gives you massively parallel independent environments, and `n_drones` gives you full swarm simulation inside each one, each drone running its own accurate, identified flight dynamics and control stack. Scaling to millions of parallel instances requires no code changes. -Simulating the full Crazyflie firmware stack with GPU acceleration and differentiability is not possible with existing tools, so Crazyflow reimplements the entire dynamics and control stack in JAX. This gives accelerated, fully batchable simulation that runs on CPU and GPU without modification. Differentiability comes as a direct consequence: `jax.grad` works through physics, control, and integration without any manual gradient derivations, enabling gradient-based policy optimization, system identification, and sensitivity analysis out of the box. +Simulating the full Crazyflie firmware stack with GPU acceleration and differentiability is not possible with existing tools, so Crazyflow reimplements the entire dynamics and control stack in JAX. This gives accelerated, fully batchable simulation that runs on CPU and GPU without modification. Differentiability comes as a direct consequence: `jax.grad` works through dynamics, control, and integration without any manual gradient derivations, enabling gradient-based policy optimization, system identification, and sensitivity analysis out of the box. To make research possible rather than just evaluation, the simulator is designed to be fully open to modification. The step and reset pipelines are plain ordered dictionaries of pure JAX functions. Helper functions in `crazyflow.sim.pipeline` (`append_fn`, `prepend_fn`, `insert_fn_before`, `insert_fn_after`, `replace_fn`, `remove_fn`) let you safely add, reorder, and swap stages by name. There are no fixed hooks or plugin interfaces. You splice in your own dynamics, disturbances, randomization, or reward shaping at any point, and the JIT compiler fuses everything into a single kernel. diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css index e951678..29da1f7 100644 --- a/docs/stylesheets/extra.css +++ b/docs/stylesheets/extra.css @@ -219,7 +219,7 @@ } } -/* Drone model grid */ +/* Drone grid */ .drone-grid table { margin: 0 auto; } @@ -269,4 +269,4 @@ max-width: 100%; max-height: 560px; object-fit: contain; -} +} \ No newline at end of file diff --git a/docs/user-guide/control/batching.md b/docs/user-guide/control/batching.md new file mode 100644 index 0000000..4d12c91 --- /dev/null +++ b/docs/user-guide/control/batching.md @@ -0,0 +1,41 @@ +# Batching + +All controllers are built on Array API operations that broadcast over leading dimensions. Add a leading batch dimension to your state and command arrays and the controller evaluates all instances in a single call, with no loops and no special API. + +```python +import numpy as np +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250") + +N = 100 +pos = np.zeros((N, 3)) +quat = np.tile(np.array([0.0, 0.0, 0.0, 1.0]), (N, 1)) +vel = np.zeros((N, 3)) +cmd = np.zeros((N, 13)) + +rpyt, int_pos_err = ctrl(pos, quat, vel, cmd) +rpyt.shape # (100, 4) +``` + +## Higher-dimensional batches + +Any number of leading dimensions works. A common pattern is a grid of environments, each containing multiple drones: + +```python +import numpy as np +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250") + +# 10 environments, 5 drones each +pos = np.zeros((10, 5, 3)) +quat = np.broadcast_to(np.array([0.0, 0.0, 0.0, 1.0]), (10, 5, 4)).copy() +vel = np.zeros((10, 5, 3)) +cmd = np.zeros((10, 5, 13)) + +rpyt, _ = ctrl(pos, quat, vel, cmd) +rpyt.shape # (10, 5, 4) +``` diff --git a/docs/user-guide/control/controllers.md b/docs/user-guide/control/controllers.md new file mode 100644 index 0000000..6c04330 --- /dev/null +++ b/docs/user-guide/control/controllers.md @@ -0,0 +1,27 @@ +# Controllers + +A controller is a function that maps the current drone state and a command to actuator outputs. Every controller in `crazyflow.control` is: + +- **A pure function**: no hidden state; integral errors are explicit return values you pass back on the next call +- **Array-API compatible**: works identically with NumPy, JAX, PyTorch, or any compliant library +- **Batchable**: add leading dimensions to any input array and the function evaluates all instances at once + +## The Mellinger pipeline + +The Mellinger controller [[1]](#references) is split into three stages that form a pipeline. Each stage can be used on its own, or all three can be chained to convert a full-state setpoint into individual motor speeds. + +| Stage | Function | Takes | Produces | +|---|---|---|---| +| 1 | [`state2attitude`](mellinger.md#state-to-attitude) | State + 13-element setpoint | RPYT command + position integral error | +| 2 | [`attitude2force_torque`](mellinger.md#attitude-to-force-torque) | Attitude + RPYT command | Collective force, body torques + angular velocity integral error | +| 3 | [`force_torque2rotor_vel`](mellinger.md#force-torque-to-rotor-velocities) | Force + torques | 4 motor speeds [RPM] | + +## Available controllers + +| Module | Controller | Stages | +|---|---|---| +| `crazyflow.control.mellinger` | Mellinger | `state2attitude`, `attitude2force_torque`, `force_torque2rotor_vel` | + +## References + +[1] D. Mellinger and V. Kumar, "Minimum snap trajectory generation and control for quadrotors," ICRA 2011, doi: 10.1109/ICRA.2011.5980409. diff --git a/docs/user-guide/control-modes.md b/docs/user-guide/control/index.md similarity index 68% rename from docs/user-guide/control-modes.md rename to docs/user-guide/control/index.md index 57e82ea..95a8f4a 100644 --- a/docs/user-guide/control-modes.md +++ b/docs/user-guide/control/index.md @@ -60,10 +60,10 @@ sim.step(sim.freq // sim.control_freq) ## Attitude control ```python -from crazyflow.sim import Sim, Physics +from crazyflow.sim import Sim, Dynamics from crazyflow.control import Control -sim = Sim(control=Control.attitude, physics=Physics.so_rpy, attitude_freq=500) +sim = Sim(control=Control.attitude, dynamics=Dynamics.so_rpy, attitude_freq=500) sim.reset() ``` @@ -80,10 +80,10 @@ For a hover command, set thrust to `mass × g`: ```python import numpy as np -from crazyflow.sim import Sim, Physics +from crazyflow.sim import Sim, Dynamics from crazyflow.control import Control -sim = Sim(control=Control.attitude, physics=Physics.so_rpy) +sim = Sim(control=Control.attitude, dynamics=Dynamics.so_rpy) sim.reset() mass = float(sim.data.params.mass[0, 0, 0]) @@ -96,7 +96,7 @@ sim.step(sim.freq // sim.control_freq) ## Force-torque control -Direct force and torque input. Requires `Physics.first_principles`. +Direct force and torque input. Requires `Dynamics.first_principles`. Command shape: `(n_worlds, n_drones, 4)` @@ -109,10 +109,10 @@ Command shape: `(n_worlds, n_drones, 4)` ```python import numpy as np -from crazyflow.sim import Sim, Physics +from crazyflow.sim import Sim, Dynamics from crazyflow.control import Control -sim = Sim(control=Control.force_torque, physics=Physics.first_principles) +sim = Sim(control=Control.force_torque, dynamics=Dynamics.first_principles) sim.reset() mass = float(sim.data.params.mass[0, 0, 0]) @@ -125,7 +125,7 @@ sim.step(1) ## Rotor velocity control -Direct motor commands. Requires `Physics.first_principles`. +Direct motor commands. Requires `Dynamics.first_principles`. Command shape: `(n_worlds, n_drones, 4)` @@ -137,10 +137,10 @@ The hover RPM for `cf2x_L250` is approximately 15 000 RPM, but the exact value d ```python import numpy as np -from crazyflow.sim import Sim, Physics +from crazyflow.sim import Sim, Dynamics from crazyflow.control import Control -sim = Sim(control=Control.rotor_vel, physics=Physics.first_principles) +sim = Sim(control=Control.rotor_vel, dynamics=Dynamics.first_principles) sim.reset() cmd = np.full((1, 1, 4), 15_000.0, dtype=np.float32) @@ -151,18 +151,29 @@ sim.step(1) ## Control frequency -Each control mode has its own update rate. The physics tick (`freq`) is always the fastest. +Each control mode has its own update rate. The dynamics tick (`freq`) is always the fastest. | Mode | Rate argument | Default | |---|---|---| | `state` | `state_freq` | 100 Hz | | `attitude` | `attitude_freq` | 500 Hz | | `force_torque` | `force_torque_freq` | 500 Hz | -| `rotor_vel` | — | every physics step | +| `rotor_vel` | — | every dynamics step | -The simulator applies a new command only when the control tick fires. Between ticks, the previous command is held. The number of physics steps per control tick is `freq // control_freq`. +The simulator applies a new command only when the control tick fires. Between ticks, the previous command is held. The number of dynamics steps per control tick is `freq // control_freq`. + +## Using the controllers standalone + +The control modes above are how the simulator drives the onboard controllers. Those controllers also live in `crazyflow.control` as a self-contained library of pure functions, usable on their own for control design, learning-based policies, or as a reference implementation, independent of `Sim`. The following guides cover that standalone API: + +- [Controllers](controllers.md): the controller interface and the Mellinger pipeline +- [Mellinger controller](mellinger.md): the three stages, their inputs and outputs +- [Parametrization](parametrize.md): binding a controller to a drone configuration +- [Integral errors](integral-errors.md): carrying controller state across calls +- [Batching](batching.md): evaluating many drones at once +- [JIT compilation](jit.md): compiling controllers with `jax.jit` ## Next steps -- [Functional API](functional-api.md) — running control inside JIT with `F.controllable` -- [Physics Models](physics-models.md) — compatibility between physics and control modes +- [Functional API](../functional-api.md): running control inside JIT with `F.controllable` +- [Dynamics](../dynamics/index.md): compatibility between dynamics and control modes diff --git a/docs/user-guide/control/integral-errors.md b/docs/user-guide/control/integral-errors.md new file mode 100644 index 0000000..2f6223e --- /dev/null +++ b/docs/user-guide/control/integral-errors.md @@ -0,0 +1,79 @@ +# Integral errors + +The Mellinger controller uses integral terms to reject steady-state errors. Because every controller in `crazyflow.control` is a **pure function**, integral state cannot be stored inside the function. It is an explicit return value that you pass back on the next call. + +## Initialisation + +You have two ways to start the integral error at zero: + +1. Pass `None` (or omit the argument). The controller then creates the zero array internally. +2. Create the zero array yourself and pass it explicitly. The integral error has the same shape as the position (for `pos_err_i`) or angular velocity (for `r_int_error`), so a zero array of that shape works. + +```python +import numpy as np +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250") +pos = np.zeros(3) +quat = np.array([0.0, 0.0, 0.0, 1.0]) +vel = np.zeros(3) +cmd = np.zeros(13) + +# Option 1: let the controller initialise the integral error. +rpyt, pos_err_i = ctrl(pos, quat, vel, cmd, pos_err_i=None) + +# Option 2: initialise it yourself. +rpyt, pos_err_i = ctrl(pos, quat, vel, cmd, pos_err_i=np.zeros(3)) +``` + +Under `jax.jit`, prefer option 2. Passing `None` on the first call and an array on the following calls changes the argument structure, so JAX traces and compiles the function twice. Passing a zero array from the start keeps the input structure constant and compiles only once. + +## Carrying errors across timesteps + +Pass the returned error straight back as `pos_err_i` on the next call: + +```python +import numpy as np +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250") +pos = np.zeros(3) +quat = np.array([0.0, 0.0, 0.0, 1.0]) +vel = np.zeros(3) +cmd = np.zeros(13) +cmd[0] = 1.0 # 1 m setpoint error in x + +pos_err_i = None +for _ in range(10): + rpyt, pos_err_i = ctrl(pos, quat, vel, cmd, pos_err_i=pos_err_i) # carry forward + +# After 10 steps at 100 Hz, integral error is approximately 10 * 0.01 = 0.1 m. +``` + +## Both stages have integral errors + +`state2attitude` tracks position error via `pos_err_i`. `attitude2force_torque` tracks angular velocity error via `r_int_error`. Manage them independently: + +```python +import numpy as np +from crazyflow.control import parametrize +from crazyflow.control.mellinger import attitude2force_torque, state2attitude + +state_ctrl = parametrize(state2attitude, "cf2x_L250") +att_ctrl = parametrize(attitude2force_torque, "cf2x_L250") + +pos = np.zeros(3) +quat = np.array([0.0, 0.0, 0.0, 1.0]) +vel = np.zeros(3) +ang_vel = np.zeros(3) +cmd = np.zeros(13) + +pos_err_i = None +r_int_error = None + +for _ in range(5): + rpyt, pos_err_i = state_ctrl(pos, quat, vel, cmd, pos_err_i=pos_err_i) + force, torque, r_int_error = att_ctrl(quat, ang_vel, rpyt, r_int_error=r_int_error) +``` diff --git a/docs/user-guide/control/jit.md b/docs/user-guide/control/jit.md new file mode 100644 index 0000000..7aa707f --- /dev/null +++ b/docs/user-guide/control/jit.md @@ -0,0 +1,66 @@ +# JIT compilation + +Every controller is a pure function with no hidden state or side effects. In addition, they are implemented exclusively with Array API operations that are compatible with lazy JIT frameworks. Together, these two properties mean that every controller can be JIT compiled without any modification. + +```python +import jax +import jax.numpy as jnp +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250", xp=jnp) +jit_ctrl = jax.jit(ctrl) + +pos = jnp.zeros(3) +quat = jnp.array([0.0, 0.0, 0.0, 1.0]) +vel = jnp.zeros(3) +cmd = jnp.zeros(13) + +rpyt, int_pos_err = jit_ctrl(pos, quat, vel, cmd) +``` + +## Integral errors under JIT + +Integral errors are regular arrays and are handled as JAX pytree leaves, so they pass through `jax.jit` without any special treatment. + +```python +import jax +import jax.numpy as jnp +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250", xp=jnp) +jit_ctrl = jax.jit(ctrl) + +pos = jnp.zeros(3) +quat = jnp.array([0.0, 0.0, 0.0, 1.0]) +vel = jnp.zeros(3) +cmd = jnp.zeros(13) + +pos_err_i = jnp.zeros(3) # initialise to zero, so the function compiles only once +for _ in range(10): + rpyt, pos_err_i = jit_ctrl(pos, quat, vel, cmd, pos_err_i=pos_err_i) +``` + +## Batched JIT + +Batching and JIT compose directly. Add leading dimensions to the state arrays and the same compiled function handles the entire batch. + +```python +import jax +import jax.numpy as jnp +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250", xp=jnp) +jit_ctrl = jax.jit(ctrl) + +N = 1_000 +pos = jnp.zeros((N, 3)) +quat = jnp.broadcast_to(jnp.array([0.0, 0.0, 0.0, 1.0]), (N, 4)) +vel = jnp.zeros((N, 3)) +cmd = jnp.zeros((N, 13)) + +rpyt, _ = jit_ctrl(pos, quat, vel, cmd) +rpyt.shape # (1000, 4) +``` diff --git a/docs/user-guide/control/mellinger.md b/docs/user-guide/control/mellinger.md new file mode 100644 index 0000000..07136e5 --- /dev/null +++ b/docs/user-guide/control/mellinger.md @@ -0,0 +1,153 @@ +# Mellinger controller + +The Mellinger controller converts a full-state setpoint into individual motor speeds through three chained pure functions. The implementation closely follows the Crazyflie firmware to minimise sim-to-real gap. + +## State representation + +All three stages share the same state convention: + +| Variable | Shape | Units | Description | +|---|---|---|---| +| `pos` | `(..., 3)` | m | Position in world frame | +| `quat` | `(..., 4)` | | Attitude as unit quaternion, scalar-last `xyzw` | +| `vel` | `(..., 3)` | m/s | Linear velocity in world frame | +| `ang_vel` | `(..., 3)` | rad/s | Angular velocity in body frame | + +## Stage 1: State to attitude {#state-to-attitude} + +`state2attitude` is the position control loop. It converts a full-state setpoint into an attitude and collective thrust command (RPYT). + +**Inputs:** + +| Argument | Shape | Description | +|---|---|---| +| `pos` | `(..., 3)` | Current position [m] | +| `quat` | `(..., 4)` | Current attitude, xyzw | +| `vel` | `(..., 3)` | Current velocity [m/s] | +| `cmd` | `(..., 13)` | Setpoint: `[x, y, z, vx, vy, vz, ax, ay, az, yaw, avx, avy, avz]` | +| `pos_err_i` | `(..., 3)` or `None` | Position integral error from the previous call. `None` initialises to zero | +| `ctrl_freq` | `float` | Control frequency in Hz (default 100) | + +**Outputs:** + +| Return | Shape | Description | +|---|---|---| +| `rpyt` | `(..., 4)` | Attitude + thrust: `[roll_rad, pitch_rad, yaw_rad, thrust_N]` | +| `pos_err_i` | `(..., 3)` | Position integral error. Pass back as `pos_err_i` on the next call | + +```python +import numpy as np +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250") + +pos = np.zeros(3) +quat = np.array([0.0, 0.0, 0.0, 1.0]) +vel = np.zeros(3) +cmd = np.zeros(13) # setpoint at origin, yaw = 0 + +rpyt, pos_err_i = ctrl(pos, quat, vel, cmd) +rpyt.shape # (4,) +pos_err_i.shape # (3,) +``` + +## Stage 2: Attitude to force/torque {#attitude-to-force-torque} + +`attitude2force_torque` is the attitude control loop. It converts an RPYT command into collective thrust and body-frame torques. + +**Inputs:** + +| Argument | Shape | Description | +|---|---|---| +| `quat` | `(..., 4)` | Current attitude, xyzw | +| `ang_vel` | `(..., 3)` | Current angular velocity in body frame [rad/s] | +| `cmd` | `(..., 4)` | RPYT from stage 1: `[roll_rad, pitch_rad, yaw_rad, thrust_N]` | +| `prev_ang_vel` | `(..., 3)` or `None` | Angular velocity from the previous call. `None` initialises to zero | +| `r_int_error` | `(..., 3)` or `None` | Angular velocity integral error from the previous call. `None` initialises to zero | +| `ctrl_freq` | `int` | Control frequency in Hz (default 500) | + +**Outputs:** + +| Return | Shape | Description | +|---|---|---| +| `force` | `(..., 1)` | Collective thrust [N] | +| `torque` | `(..., 3)` | Body-frame torques [N·m] | +| `r_int_error` | `(..., 3)` | Angular velocity integral error. Pass back as `r_int_error` on the next call | + +```python +import numpy as np +from crazyflow.control import parametrize +from crazyflow.control.mellinger import attitude2force_torque + +ctrl = parametrize(attitude2force_torque, "cf2x_L250") + +quat = np.array([0.0, 0.0, 0.0, 1.0]) # identity, no rotation +ang_vel = np.zeros(3) +cmd = np.array([0.0, 0.0, 0.0, 0.3]) # level attitude, 0.3 N thrust + +force, torque, r_int_err = ctrl(quat, ang_vel, cmd) +force.shape # (1,) +torque.shape # (3,) +``` + +## Stage 3: Force/torque to rotor velocities {#force-torque-to-rotor-velocities} + +`force_torque2rotor_vel` converts collective thrust and body-frame torques into individual motor speeds, accounting for the motor mixing matrix. + +**Inputs:** + +| Argument | Shape | Description | +|---|---|---| +| `force` | `(..., 1)` | Desired collective thrust [N] | +| `torque` | `(..., 3)` | Desired body-frame torques [N·m] | + +**Outputs:** + +| Return | Shape | Description | +|---|---|---| +| `rotor_speeds` | `(..., 4)` | Individual motor speeds [RPM] | + +```python +import numpy as np +from crazyflow.control import parametrize +from crazyflow.control.mellinger import force_torque2rotor_vel + +ctrl = parametrize(force_torque2rotor_vel, "cf2x_L250") + +force = np.array([0.2]) # total thrust [N] +torque = np.zeros(3) # no corrective torque + +rotor_speeds = ctrl(force, torque) +rotor_speeds.shape # (4,) +``` + +## Chaining all three stages + +```python +import numpy as np +from crazyflow.control import parametrize +from crazyflow.control.mellinger import ( + attitude2force_torque, + force_torque2rotor_vel, + state2attitude, +) + +state_ctrl = parametrize(state2attitude, "cf2x_L250") +att_ctrl = parametrize(attitude2force_torque, "cf2x_L250") +rotor_ctrl = parametrize(force_torque2rotor_vel, "cf2x_L250") + +pos = np.array([0.0, 0.0, 1.0]) # 1 m altitude +quat = np.array([0.0, 0.0, 0.0, 1.0]) +vel = np.zeros(3) +ang_vel = np.zeros(3) +cmd = np.zeros(13) +cmd[:3] = np.array([0.0, 0.0, 1.0]) # hover at 1 m + +rpyt, _ = state_ctrl(pos, quat, vel, cmd) +force, torque, _ = att_ctrl(quat, ang_vel, rpyt) +rotor_speeds = rotor_ctrl(force, torque) +rotor_speeds.shape # (4,) +``` + +Integral errors from each stage should be passed back on the next call. See [Integral errors](integral-errors.md) for the full pattern. diff --git a/docs/user-guide/control/parametrize.md b/docs/user-guide/control/parametrize.md new file mode 100644 index 0000000..1fc35cf --- /dev/null +++ b/docs/user-guide/control/parametrize.md @@ -0,0 +1,113 @@ +# Parametrize + +Every controller function takes physical parameters (gains, mass, mixing matrix, PWM bounds) as keyword-only arguments, and the exact values differ per drone. Rather than passing them at every call site, [`parametrize`][crazyflow.control.parametrize] loads them for a named drone and binds them upfront, so call sites only need to provide state and command. + +The parameters stay individually accessible after binding. Because they are plain keyword-argument defaults on a `functools.partial`, any of them can be overridden at call time, or batched across a set of environments, without re-parametrizing the function. This makes it straightforward to randomize physical properties across a simulated batch. + +```python +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250") + +# Inspect what was bound +list(ctrl.keywords.keys()) +# ['mass', 'kp', 'kd', 'ki', 'gravity_vec', 'mass_thrust', +# 'int_err_max', 'thrust_max', 'pwm_max'] +``` + +## Overriding parameters at call time + +Because `parametrize` returns a `functools.partial`, the bound parameters are just keyword-argument defaults. Pass a different value at call time to override for that call only; `ctrl.keywords` is not modified: + +```python +import numpy as np +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250") +pos = np.zeros(3) +quat = np.array([0.0, 0.0, 0.0, 1.0]) +vel = np.zeros(3) +cmd = np.zeros(13) + +# Simulate with a heavier drone for this call only. +rpyt, _ = ctrl(pos, quat, vel, cmd, mass=0.035) +``` + +To make a change persist across all future calls, mutate `ctrl.keywords` directly: + +```python +import numpy as np +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250") +ctrl.keywords["mass"] = np.float64(0.035) +``` + +!!! warning + `ctrl.keywords` is a mutable dict shared across all references to the same partial. Call `parametrize` again for an independent copy. + +## Available drone configurations + +The following configurations ship with pre-fitted parameters: + +| `drone` | Platform | +|---|---| +| `"cf2x_L250"` | Crazyflie 2.x, L250 props | +| `"cf2x_P250"` | Crazyflie 2.x, P250 props | +| `"cf2x_T350"` | Crazyflie 2.x, T350 props | +| `"cf21B_500"` | Crazyflie 2.1 Brushless, 500 props | + +Pass the drone name as a plain string: + +```python +import numpy as np +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250") +pos = np.zeros(3) +quat = np.array([0.0, 0.0, 0.0, 1.0]) +vel = np.zeros(3) +cmd = np.zeros(13) +rpyt, _ = ctrl(pos, quat, vel, cmd) +``` + +## Loading raw parameters + +Use [`load_params`][crazyflow.control.core.load_params] to inspect or override the values that `parametrize` would bind for a specific controller function: + +```python +from crazyflow.control.core import load_params +from crazyflow.control.mellinger import state2attitude + +params = load_params(state2attitude, "cf2x_L250") +float(params["mass"]) # 0.029 +``` + +## Switching array backends + +By default parameters are stored as NumPy arrays. Pass `xp` to convert them upfront, which avoids per-call conversion overhead in frameworks like PyTorch or JAX: + +```{ .python notest } +import torch +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250", xp=torch) +``` + +You can also specify a compute device: + +```python +import jax +import jax.numpy as jnp +from crazyflow.control import parametrize +from crazyflow.control.mellinger import state2attitude + +ctrl = parametrize(state2attitude, "cf2x_L250", xp=jnp, device=jax.devices("cpu")[0]) +``` + +The output backend is always inferred from the arrays you pass at call time, regardless of where the parameters live. diff --git a/docs/user-guide/dynamics/batching.md b/docs/user-guide/dynamics/batching.md new file mode 100644 index 0000000..b3ab6f0 --- /dev/null +++ b/docs/user-guide/dynamics/batching.md @@ -0,0 +1,99 @@ +# Batching and domain randomization + +All dynamics are built on Array API operations that broadcast over leading dimensions. There is no explicit batch size argument — just add a leading batch dimension to your state and command arrays and the dynamics evaluate all instances in a single call. This works identically across all backends. + +```python +import jax.numpy as jnp +from crazyflow.dynamics import parametrize +from crazyflow.dynamics.first_principles import dynamics + +dynamics = parametrize(dynamics, drone="cf2x_L250", xp=jnp) + +N = 1_000 +pos = jnp.zeros((N, 3)) +quat = jnp.tile(jnp.array([0.0, 0.0, 0.0, 1.0]), (N, 1)) +vel = jnp.zeros((N, 3)) +ang_vel = jnp.zeros((N, 3)) +cmd = jnp.full((N, 4), 15_000.0) +rotor_vel = jnp.full((N, 4), 15_000.0) + +pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = dynamics( + pos, quat, vel, ang_vel, cmd, rotor_vel=rotor_vel +) +vel_dot.shape # (1000, 3) +``` + +## Higher-dimensional batches + +Any number of leading dimensions works. A common pattern is a grid of environments, each containing multiple drones: + +```python +import jax.numpy as jnp +from crazyflow.dynamics import parametrize +from crazyflow.dynamics.first_principles import dynamics + +dynamics = parametrize(dynamics, drone="cf2x_L250", xp=jnp) +# 50 environments, 20 drones each +pos = jnp.zeros((50, 20, 3)) +quat = jnp.broadcast_to(jnp.array([0.0, 0.0, 0.0, 1.0]), (50, 20, 4)) +vel = jnp.zeros((50, 20, 3)) +ang_vel = jnp.zeros((50, 20, 3)) +rotor_vel = jnp.full((50, 20, 4), 12_000.0) +cmd = jnp.full((50, 20, 4), 15_000.0) + +vel_dot = dynamics(pos, quat, vel, ang_vel, cmd, rotor_vel)[2] +vel_dot.shape # (50, 20, 3) +``` + +## Domain randomization + +Training policies across a distribution of physical parameters — domain randomization — improves sim-to-real transfer. Because `parametrize` returns a `functools.partial`, physical parameters are just keyword argument defaults. There are two ways to vary them across a batch. + +**Option 1 — pass parameters as call-time kwargs.** This is the preferred pattern when using JIT compilation, because JAX traces the parameters as inputs rather than capturing them as constants. You can then draw fresh parameters each rollout without triggering a recompile. + +```python +import jax +import jax.numpy as jnp +from crazyflow.dynamics import parametrize +from crazyflow.dynamics.first_principles import dynamics + +N = 256 +key = jax.random.PRNGKey(0) + +pos, vel, ang_vel = jnp.zeros((N, 3)), jnp.zeros((N, 3)), jnp.zeros((N, 3)) +quat = jnp.tile(jnp.array([0.0, 0.0, 0.0, 1.0]), (N, 1)) +cmd = jnp.full((N, 4), 15_000.0) +rotor_vel = jnp.full((N, 4), 15_000.0) +dynamics = parametrize(dynamics, drone="cf2x_L250", xp=jnp) +nominal_mass = dynamics.keywords["mass"] +nominal_J = dynamics.keywords["J"] + + +@jax.jit +def step(pos, quat, vel, ang_vel, cmd, rotor_vel, mass, J, J_inv): + return dynamics( + pos, quat, vel, ang_vel, cmd, rotor_vel=rotor_vel, mass=mass, J=J, J_inv=J_inv + ) + + +key, k1, k2 = jax.random.split(key, 3) +mass_batch = nominal_mass * jax.random.uniform(k1, (N, 1), minval=0.9, maxval=1.1) +J_batch = nominal_J * jax.random.uniform(k2, (N, 3, 3), minval=0.9, maxval=1.1) +J_inv_batch = jnp.linalg.inv(J_batch) + +vel_dot = step(pos, quat, vel, ang_vel, cmd, rotor_vel, mass_batch, J_batch, J_inv_batch)[2] +``` + +**Option 2 — mutate `dynamics.keywords` directly.** Simpler when you don't need JIT or are happy to retrace. Replace a scalar parameter with a `(N,)` array and each element in the batch uses its own value. + +```{ .python continuation } +dynamics.keywords["mass"] = nominal_mass * mass_batch # shape (N,) +vel_dot = dynamics(pos, quat, vel, ang_vel, cmd, rotor_vel)[2] +``` + +!!! note + Matrix parameters like `J` have shape `(3, 3)`. To randomize per-drone, reshape to `(N, 3, 3)` and update `J_inv` accordingly. Scalar parameters like `mass` only need shape `(N,)`. + +--- + +So far everything has been numeric. Many control frameworks — MPC, trajectory optimization, and state estimators — require symbolic dynamics representations. The next page covers the CasADi symbolic variants. diff --git a/docs/user-guide/dynamics/dynamics-functions.md b/docs/user-guide/dynamics/dynamics-functions.md new file mode 100644 index 0000000..aba10e2 --- /dev/null +++ b/docs/user-guide/dynamics/dynamics-functions.md @@ -0,0 +1,121 @@ +# Dynamics functions + +A dynamics function takes the current state of the drone and a command, and returns the time derivatives of that state. Integrate those derivatives forward and you have a simulation step; evaluate them symbolically and you have an MPC model. The same function serves both purposes. + +Every dynamics in `crazyflow.dynamics` shares the same state representation: + +| Variable | Shape | Description | +|---|---|---| +| `pos` | `(3,)` | Position in world frame [m] | +| `quat` | `(4,)` | Attitude as unit quaternion, scalar-last `xyzw` | +| `vel` | `(3,)` | Linear velocity in world frame [m/s] | +| `ang_vel` | `(3,)` | Angular velocity in body frame [rad/s] | + +What differs between dynamics is the command interface, which parameters are needed, and how much physical detail is captured. The table below gives a quick overview — the sections that follow explain each one and when to reach for it. + +| Module | `cmd` input | Rotor dynamics | Key added params | +|---|---|---|---| +| `first_principles` | Motor RPMs `(4,)` | Yes | `rpm2thrust`, `rpm2torque`, `mixing_matrix`, `L`, `prop_inertia` | +| `so_rpy_rotor_drag` | rpyt `(4,)` | Yes | `thrust_time_coef`, `drag_matrix` | +| `so_rpy_rotor` | rpyt `(4,)` | Yes | `thrust_time_coef` | +| `so_rpy` | rpyt `(4,)` | No | — | + +## first_principles + +The full rigid-body dynamics, derived analytically from first principles. The command is four individual motor RPMs. It computes forces and torques from the RPMs using polynomial thrust and torque curves, applies the mixing matrix to find body-frame moments, and integrates using Newton–Euler equations. Propeller inertia and gyroscopic effects are included. No fitting to flight data is required — all parameters are physical constants you can measure or look up. + +Working at the rotor-velocity level means you need a controller that converts higher-level commands, such as position setpoints or attitude plus collective thrust, down to individual motor RPMs. The [controllers](../control/index.md) in `crazyflow.control` provide a matching set designed for exactly this interface. + +```python +import numpy as np +from crazyflow.dynamics import parametrize +from crazyflow.dynamics.first_principles import dynamics + +dynamics = parametrize(dynamics, drone="cf2x_L250") + +pos, vel, ang_vel = np.zeros((3,)), np.zeros((3,)), np.zeros((3,)) +quat = np.array([0.0, 0.0, 0.0, 1.0]) +cmd = np.full((4,), 15_000.0) +rotor_vel = np.full((4,), 12_000.0) + +pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = dynamics( + pos, quat, vel, ang_vel, + cmd, # shape (4,) — motor RPMs + rotor_vel, # shape (4,) — current motor RPMs; pass None to skip rotor dynamics +) +``` + +See [`crazyflow.dynamics.first_principles`][crazyflow.dynamics.first_principles] in the API reference for the full parameter list. + +## so_rpy_rotor_drag + +A fitted second-order dynamics where the command is `[roll_rad, pitch_rad, yaw_rad, thrust_N]` — the same interface used by most flight controller firmware. First-order thrust dynamics model motor spin-up delay, and a linear body-frame drag term accounts for aerodynamic resistance. All coefficients are identified from flight data rather than derived from physics, which makes it easy to calibrate and well-suited to real-time control. + +```{ .python continuation } +from crazyflow.dynamics.so_rpy_rotor_drag import dynamics + +dynamics = parametrize(dynamics, drone="cf2x_L250") + +# Reuses pos, quat, vel, ang_vel from above; the command interface is what differs +cmd = np.array([0.0, 0.0, 0.0, 0.31]) # [roll_rad, pitch_rad, yaw_rad, thrust_N] +rotor_vel = np.array([0.31]) # shape (1,) — current thrust state [N]; None to skip thrust dynamics + +pos_dot, quat_dot, vel_dot, ang_vel_dot, rotor_vel_dot = dynamics( + pos, quat, vel, ang_vel, cmd, rotor_vel +) +``` + +## so_rpy_rotor + +The same as `so_rpy_rotor_drag` but without the drag term. Use this when aerodynamic drag is negligible — for example, in low-speed indoor flight — or when you want a slightly simpler dynamics to calibrate. + +```python +from crazyflow.dynamics.so_rpy_rotor import dynamics +``` + +## so_rpy + +The simplest dynamics: no rotor dynamics, no drag. The attitude dynamics are a fitted second-order system driven directly by the roll/pitch/yaw command. It does not accept a `rotor_vel` argument and returns four derivatives instead of five. This is the fastest to evaluate and the easiest to understand, making it a good baseline for control design and learning-based methods where simulation throughput matters most. + +```python +from crazyflow.dynamics.so_rpy import dynamics +``` + +## External disturbances + +All four dynamics accept optional `dist_f` (external force, world frame, N) and `dist_t` (external torque, body frame, N·m) arguments. These are useful for modelling wind, contact forces, or other perturbations without modifying the dynamics itself. + +```python +import numpy as np +from crazyflow.dynamics import parametrize +from crazyflow.dynamics.so_rpy import dynamics + +dynamics = parametrize(dynamics, drone="cf2x_L250") +pos, vel, ang_vel = np.zeros((3,)), np.zeros((3,)), np.zeros((3,)) +quat = np.array([0.0, 0.0, 0.0, 1.0]) +cmd = np.array([0.0, 0.0, 0.0, 0.31]) +dist_f = np.array([0.05, 0.0, 0.0]) # 50 mN headwind [N] +dist_t = np.zeros(3) + +# so_rpy has no rotor dynamics, so it returns four derivatives +pos_dot, quat_dot, vel_dot, ang_vel_dot = dynamics( + pos, quat, vel, ang_vel, cmd, dist_f=dist_f, dist_t=dist_t +) +``` + +## Checking rotor dynamics support + +If you are writing code that works with multiple dynamics, [`dynamics_features`][crazyflow.dynamics.dynamics_features] tells you programmatically whether a given dynamics function supports rotor dynamics. + +```python +from crazyflow.dynamics import dynamics_features +from crazyflow.dynamics.first_principles import dynamics as fp +from crazyflow.dynamics.so_rpy import dynamics as srpy + +dynamics_features(fp) # {'rotor_dynamics': True} +dynamics_features(srpy) # {'rotor_dynamics': False} +``` + +--- + +With an understanding of the available dynamics, the next step is binding one to a specific drone configuration. That's what [`parametrize`](parametrize.md) does. diff --git a/docs/user-guide/dynamics/index.md b/docs/user-guide/dynamics/index.md new file mode 100644 index 0000000..60d86f9 --- /dev/null +++ b/docs/user-guide/dynamics/index.md @@ -0,0 +1,98 @@ +# Dynamics + +Crazyflow supports four dynamics, selectable via the `Dynamics` enum. All dynamics share the same state representation and control interface, so you can swap them at construction time without changing any other code. + +```python +from crazyflow.sim import Sim, Dynamics + +sim = Sim(dynamics=Dynamics.first_principles) +``` + +## Available dynamics + +| Dynamics | Enum value | Command input | Description | +|---|---|---|---| +| First principles | `Dynamics.first_principles` | Rotor RPM | Full analytical dynamics with identified parameters | +| SO(3) + RPY | `Dynamics.so_rpy` | Roll/pitch/yaw + thrust | Simplified fitted dynamics | +| SO(3) + RPY + rotor | `Dynamics.so_rpy_rotor` | Roll/pitch/yaw + thrust | Adds first-order rotor dynamics | +| SO(3) + RPY + rotor + drag | `Dynamics.so_rpy_rotor_drag` | Roll/pitch/yaw + thrust | Adds translational and rotational drag | + +`Dynamics.default` resolves to `Dynamics.first_principles`. + +## First-principles dynamics + +The first-principles dynamics derives forces and torques analytically from motor speeds using identified physical parameters: mass, arm length, propeller constants, and the full inertia tensor. It operates at the rotor-velocity level and is the most accurate dynamics for sim-to-real transfer. + +```python +from crazyflow.sim import Sim, Dynamics +from crazyflow.control import Control + +# Force-torque and rotor_vel control modes require first_principles +sim = Sim( + dynamics=Dynamics.first_principles, + control=Control.rotor_vel, +) +sim.reset() +``` + +Parameters accessible through `sim.data.params`: + +| Parameter | Description | +|---|---| +| `mass` | Drone mass, kg | +| `J` | Inertia matrix, kg·m² | +| `L` | Motor arm length, m | +| `rpm2thrust` | Thrust coefficient, N/(RPM²) | +| `rpm2torque` | Torque coefficient, Nm/(RPM²) | +| `mixing_matrix` | Maps rotor RPMs² to [thrust, tx, ty, tz] | +| `rotor_dyn_coef` | First-order rotor time constant | + +## Fitted dynamics (so_rpy family) + +The `so_rpy` dynamics are identified from flight data using a small number of flight minutes. They take higher-level commands (roll/pitch/yaw setpoints + collective thrust in Newtons) and are faster to simulate because they skip the rotor-velocity level. + +These dynamics are a good choice when: + +- You are training RL agents and want speed over fidelity +- Your controller outputs attitude targets (as most Crazyflie firmware does) +- You do not need rotor-level detail + +```python +from crazyflow.sim import Sim, Dynamics +from crazyflow.control import Control + +sim = Sim( + dynamics=Dynamics.so_rpy_rotor_drag, # most accurate of the fitted family + control=Control.attitude, +) +sim.reset() +``` + +The `so_rpy_rotor_drag` variant includes translational drag, which captures the velocity-dependent deceleration effect visible in aggressive flights. It is the recommended fitted dynamics for sim-to-real experiments. + +## Control mode compatibility + +| Dynamics | `Control.state` | `Control.attitude` | `Control.force_torque` | `Control.rotor_vel` | +|---|---|---|---|---| +| `first_principles` | ✓ | ✓ | ✓ | ✓ | +| `so_rpy` | ✓ | ✓ | ✗ | ✗ | +| `so_rpy_rotor` | ✓ | ✓ | ✗ | ✗ | +| `so_rpy_rotor_drag` | ✓ | ✓ | ✗ | ✗ | + +!!! warning + Using `Control.force_torque` or `Control.rotor_vel` with a fitted dynamics raises `ConfigError` at construction time. + +## Using the dynamics standalone + +The `Dynamics` enum above is how you select a dynamics *inside the simulator*. The dynamics also live in `crazyflow.dynamics` as a self-contained library of pure functions, usable on their own for state estimation, control design, or as MPC models independent of `Sim`. The following guides cover that standalone API: + +- [Dynamics functions](dynamics-functions.md) — the state representation, the four dynamics functions, and external disturbances +- [Parametrization](parametrize.md) — binding a dynamics function to a drone configuration +- [Batching & domain randomization](batching.md) — evaluating many drones at once and randomizing parameters +- [Symbolic dynamics](symbolic.md) — CasADi expressions for MPC and estimation +- [System identification](system-identification.md) — fitting dynamics coefficients from flight data + +## Next steps + +- [Control Modes](../control/index.md) — command shapes and the control hierarchy +- [Object-Oriented API](../oo-api.md) — full constructor arguments diff --git a/docs/user-guide/dynamics/parametrize.md b/docs/user-guide/dynamics/parametrize.md new file mode 100644 index 0000000..b77676c --- /dev/null +++ b/docs/user-guide/dynamics/parametrize.md @@ -0,0 +1,127 @@ +# Parametrize + +Each dynamics function has a large number of keyword-only parameters — mass, inertia matrix, thrust curves, drag coefficients, and so on. Passing all of them at every call would be impractical. [`parametrize`][crazyflow.dynamics.parametrize] solves this by loading the parameters for a named drone configuration and returning a [`functools.partial`](https://docs.python.org/3/library/functools.html#functools.partial) with those parameters pre-filled. You then call the returned object with just the state and command arrays. + +```python +from crazyflow.dynamics import parametrize +from crazyflow.dynamics.first_principles import dynamics + +dynamics = parametrize(dynamics, drone="cf2x_L250") + +# Inspect what was pre-filled +list(dynamics.keywords.keys()) +# ['mass', 'L', 'prop_inertia', 'gravity_vec', 'J', 'J_inv', +# 'rpm2thrust', 'rpm2torque', 'mixing_matrix', 'rotor_dyn_coef', 'drag_matrix'] +``` + +## Available drone configurations + +The following configurations ship with pre-fitted parameters. They cover both the brushed Crazyflie 2.x series and the brushless Crazyflie 2.1: + +```python +from crazyflow.drones import available_drones + +available_drones # ('cf2x_L250', 'cf2x_P250', 'cf2x_T350', 'cf21B_500') +``` + +| `drone` | Platform | +|---|---| +| `cf2x_L250` | Crazyflie 2.x, L250 props | +| `cf2x_P250` | Crazyflie 2.x, P250 props | +| `cf2x_T350` | Crazyflie 2.x, T350 props | +| `cf21B_500` | Crazyflie 2.1 Brushless, 500 props | + +If your drone is not listed, you can identify the parameters from flight data using the [system identification pipeline](system-identification.md) and inject them into any dynamics. + +## Switching array backends + +By default, `parametrize` stores parameters as NumPy arrays. For frameworks that would otherwise need to convert those arrays on every call — such as PyTorch, where NumPy arrays must become tensors — passing `xp` converts the parameters upfront. The backend of the outputs is always inferred from whatever arrays you pass in at call time. + +```{ .python notest } +import torch +import jax.numpy as jnp +from crazyflow.dynamics import parametrize +from crazyflow.dynamics.first_principles import dynamics + +dynamics_torch = parametrize(dynamics, drone="cf2x_L250", xp=torch) +dynamics_jax = parametrize(dynamics, drone="cf2x_L250", xp=jnp) +``` + +You can also specify a compute device — for example, to move JAX parameters to GPU at construction time: + +```{ .python notest } +import jax +import jax.numpy as jnp +dynamics_gpu = parametrize( + dynamics, drone="cf2x_L250", xp=jnp, device=jax.devices("gpu")[0] +) +``` + +## Overriding parameters at call time + +Because `parametrize` returns a `functools.partial`, the stored parameters are just keyword argument defaults. You can override any of them for a single call by passing a new value as a keyword argument — the call-time value takes precedence and the stored defaults are unchanged. + +```python +import numpy as np +from crazyflow.dynamics import parametrize +from crazyflow.dynamics.first_principles import dynamics + +dynamics = parametrize(dynamics, drone="cf2x_L250") + +pos = np.zeros(3) +quat = np.array([0.0, 0.0, 0.0, 1.0]) +vel = np.zeros(3) +ang_vel = np.zeros(3) +rotor_vel = np.zeros(4) +cmd = np.zeros(4) + +# Simulate with a 10 g payload for this call only — dynamics.keywords is not modified. +pos_dot, *_ = dynamics(pos, quat, vel, ang_vel, cmd, rotor_vel, mass=0.0419) +``` + +This becomes particularly useful for domain randomization: instead of baking randomized parameters into the partial, you can pass a batch of them as call-time arguments and keep the step function JIT-compiled across parameter changes. See [Batching & domain randomization](batching.md) for the full pattern. + +## Mutating stored parameters + +You can also modify `dynamics.keywords` directly for changes that should persist across all future calls: + +```python +import numpy as np +from crazyflow.dynamics import parametrize +from crazyflow.dynamics.first_principles import dynamics + +dynamics = parametrize(dynamics, drone="cf2x_L250") +dynamics.keywords["mass"] = np.float64(0.040) # heavier drone — applies to every call +``` + +!!! warning + `dynamics.keywords` is a mutable dict shared across all references to the same partial. Modifying it affects every call. Call `parametrize` again for an independent copy. + +## Selecting dynamics programmatically + +`available_dynamics` is a dict mapping dynamics names to their unparametrized functions. This is useful when selecting a dynamics by name. + +```python +from crazyflow.dynamics import available_dynamics, parametrize + +list(available_dynamics) # ['first_principles', 'so_rpy', 'so_rpy_rotor', 'so_rpy_rotor_drag'] + +dynamics = available_dynamics["so_rpy_rotor_drag"] +parametrized_dynamics = parametrize(dynamics, drone="cf2x_T350") +``` + +## Loading raw parameters + +If you need the parameter values directly, for example, to pass them to [`symbolic_dynamics`](symbolic.md), use [`load_params`][crazyflow.dynamics.core.load_params]: + +```python { .python continuation } +from crazyflow.dynamics.core import load_params + +params = load_params(dynamics, "cf2x_L250") +params["mass"] # 0.0319 +params["J_inv"] # array([...]) +``` + +--- + +With a parametrized dynamics in hand, you can evaluate a single state. The next page covers running many drones simultaneously by adding batch dimensions — and how to randomize physical parameters across that batch for domain randomization. diff --git a/docs/user-guide/dynamics/symbolic.md b/docs/user-guide/dynamics/symbolic.md new file mode 100644 index 0000000..2b052da --- /dev/null +++ b/docs/user-guide/dynamics/symbolic.md @@ -0,0 +1,107 @@ +# Symbolic dynamics (CasADi) + +Optimization-based control methods like nonlinear MPC, trajectory optimization and moving-horizon estimation need symbolic dynamics: an expression graph the solver can differentiate through and evaluate analytically. Every dynamics in `crazyflow.dynamics` has a `symbolic_dynamics` function that returns [CasADi](https://web.casadi.org/) `MX` expressions, validated to be numerically equivalent to the numeric `dynamics` implementation. + +The return signature is always `(X_dot, X, U, Y)`: + +- `X` — state vector (CasADi `MX` column vector) +- `U` — input vector +- `X_dot` — state derivative, as a symbolic expression in `X` and `U` +- `Y` — output vector (position + attitude) + +Physical parameters are bound with [`parametrize`][crazyflow.dynamics.parametrize], exactly as for the numeric dynamics. + +## first_principles + +```python +import casadi as cs +from crazyflow.dynamics.first_principles import symbolic_dynamics +from crazyflow.dynamics.core import parametrize + +symbolic_dynamics = parametrize(symbolic_dynamics, "cf2x_L250") + +# include rotor velocity in the state vector +X_dot, X, U, Y = symbolic_dynamics(model_rotor_vel=True, model_dist_f=False, model_dist_t=False) + +f = cs.Function("f", [X, U], [X_dot]) +``` + +State vector layout with `model_rotor_vel=True`: + +| Indices | Variable | Units | +|---|---|---| +| 0–2 | `pos` | m | +| 3–6 | `quat` (xyzw) | — | +| 7–9 | `vel` | m/s | +| 10–12 | `ang_vel` | rad/s | +| 13–16 | `rotor_vel` | RPM | + +See [`crazyflow.dynamics.first_principles`][crazyflow.dynamics.first_principles] in the API reference for the full list of accepted parameters. + +## Fitted dynamics — quaternion form + +`symbolic_dynamics` on the fitted dynamics converts them to quaternion + angular velocity state, matching the `dynamics` function signature. This makes it straightforward to swap between `first_principles` and a fitted dynamics in a solver without changing the state layout. + +```python +from crazyflow.dynamics.so_rpy_rotor_drag import symbolic_dynamics +from crazyflow.dynamics.core import parametrize + +X_dot, X, U, Y = parametrize(symbolic_dynamics, "cf2x_L250")(model_rotor_vel=True) +``` + +## Fitted dynamics — Euler form + +The fitted dynamics also expose `symbolic_dynamics_euler`, which works directly in roll/pitch/yaw + RPY-rate state. This is the natural representation of these dynamics — they are fitted in Euler angles — and it avoids the trigonometric overhead of converting to and from quaternions inside the solver. For most NMPC applications on the fitted dynamics, this is the variant to use. + +```python +from crazyflow.dynamics.so_rpy_rotor_drag import symbolic_dynamics_euler +from crazyflow.dynamics.core import parametrize + +symbolic_dynamics_euler = parametrize(symbolic_dynamics_euler, "cf2x_L250") +X_dot, X, U, Y = symbolic_dynamics_euler(model_rotor_vel=True) +``` + +State vector layout with `model_rotor_vel=True`: + +| Indices | Variable | Units | +|---|---|---| +| 0–2 | `pos` | m | +| 3–5 | `rpy` (roll/pitch/yaw) | rad | +| 6–8 | `vel` | m/s | +| 9–11 | `drpy` (RPY rates) | rad/s | +| 12 | thrust state | N | + +## Wrapping for Acados / IPOPT + +Both functions return raw CasADi expressions. Wrap them in a `cs.Function` to pass to any CasADi-based solver: + +```python +import casadi as cs +from crazyflow.dynamics.so_rpy_rotor_drag import symbolic_dynamics_euler +from crazyflow.dynamics.core import parametrize + +symbolic_dynamics_euler = parametrize(symbolic_dynamics_euler, "cf2x_L250") +X_dot, X, U, Y = symbolic_dynamics_euler(model_rotor_vel=True) + +f = cs.Function("f", [X, U], [X_dot]) +# Pass f directly to Acados, IPOPT, or any CasADi-based solver +``` + +## With disturbance states + +Setting `model_dist_f=True` or `model_dist_t=True` appends the disturbance vectors to the state, which is useful for augmented-state estimators: + +```python +from crazyflow.dynamics.first_principles import symbolic_dynamics +from crazyflow.dynamics.core import parametrize + +symbolic_dynamics = parametrize(symbolic_dynamics, "cf2x_L250") + +# dist_f (3,) and dist_t (3,) appended to state +X_dot, X, U, Y = symbolic_dynamics(model_rotor_vel=True, model_dist_f=True, model_dist_t=True) +# X is now 17 + 3 + 3 = 23 elements long +``` + +--- + +The dynamics covered so far all come with pre-fitted parameters for the supported Crazyflie platforms. For other drones, the next page explains how to extract parameters from your own flight data. diff --git a/docs/user-guide/dynamics/system-identification.md b/docs/user-guide/dynamics/system-identification.md new file mode 100644 index 0000000..7818ecd --- /dev/null +++ b/docs/user-guide/dynamics/system-identification.md @@ -0,0 +1,117 @@ +# System identification + +If your drone is not among the [supported configurations](parametrize.md#available-drone-configurations), or if you want to refine the existing parameters with your own hardware, the system identification pipeline fits the dynamics coefficients from recorded flight data. It handles data preprocessing, derivative estimation, and least-squares parameter fitting for both translational and rotational dynamics. + +The pipeline is part of `crazyflow.dynamics.utils`. Plotting the fit (`plot=True`) additionally requires `matplotlib`. + +## Required data format + +The pipeline expects a Python dict of NumPy arrays assembled from your flight log. The keys below are required by [`preprocessing`][crazyflow.dynamics.utils.data_utils.preprocessing]: + +| Key | Shape | Units | Description | +|---|---|---|---| +| `"time"` | `(N,)` | s | Timestamps (need not be evenly spaced) | +| `"pos"` | `(N, 3)` | m | Position in world frame | +| `"quat"` | `(N, 4)` | — | Orientation quaternion (xyzw) | +| `"cmd_rpy"` | `(N, 3)` | rad | Commanded roll/pitch/yaw | +| `"cmd_f"` | `(N,)` | N | Commanded collective thrust | + +After `preprocessing` + [`derivatives_svf`][crazyflow.dynamics.utils.data_utils.derivatives_svf], the dict is augmented with filtered signals and numerical derivatives. The identification functions read `SVF_vel`, `SVF_acc`, `SVF_quat`, `SVF_cmd_f` (translation) and `SVF_rpy`, `SVF_cmd_rpy` (rotation). + +## Full pipeline + +```{ .python notest } +from crazyflow.dynamics.utils.data_utils import preprocessing, derivatives_svf +from crazyflow.dynamics.utils.identification import sys_id_translation, sys_id_rotation + +# Step 1 — assemble raw data dict from your flight log +data = { + "time": time_array, # (N,) seconds + "pos": pos_array, # (N, 3) metres + "quat": quat_array, # (N, 4) xyzw + "cmd_rpy": cmd_rpy_array, # (N, 3) radians + "cmd_f": cmd_f_array, # (N,) Newtons +} + +# Step 2 — outlier removal, quaternion normalisation, RPY calculation +data = preprocessing(data) + +# Step 3 — low-pass filter and compute time derivatives via State Variable Filter +data = derivatives_svf(data) + +# Step 4 — fit translational parameters +trans_params = sys_id_translation( + dynamics="so_rpy_rotor_drag", + mass=0.0319, # drone mass in kg — measure this directly + data=data, + verbose=0, # 0 = silent, 1 = progress, 2 = full optimizer output + plot=True, # show fit vs. measured plots +) +# Returns: {'cmd_f_coef': ..., 'thrust_time_coef': ..., +# 'drag_xy_coef': ..., 'drag_z_coef': ...} + +# Step 5 — fit rotational parameters +rot_params = sys_id_rotation(data=data, verbose=0, plot=True) +# Returns: {'rpy_coef': (3,), 'rpy_rates_coef': (3,), 'cmd_rpy_coef': (3,)} +``` + +See [`sys_id_translation`][crazyflow.dynamics.utils.identification.sys_id_translation] and [`sys_id_rotation`][crazyflow.dynamics.utils.identification.sys_id_rotation] in the API reference for the full argument list. + +## Validation + +To check that the identified parameters generalise to unseen flight regimes, collect a second dataset of different trajectories and pass it as `data_validation`. RMSE and R² are then reported on both the training data and the validation data. + +```{ .python notest } +# Preprocess the validation dataset independently — it must come from +# different trajectories, not a split of the same recording. +data_valid = preprocessing(validation_raw_data) +data_valid = derivatives_svf(data_valid) + +trans_params = sys_id_translation( + dynamics="so_rpy_rotor_drag", + mass=0.0319, + data=data, + data_validation=data_valid, + plot=True, +) +``` + +## Using identified parameters + +Once you have the identified coefficients, add them to the relevant `params.toml` file under a new drone name. Each dynamics sub-package ships its own `params.toml` — for example `crazyflow/dynamics/so_rpy_rotor_drag/params.toml` — and `load_params` reads from it when you call `parametrize`. Add a new section using the TOML table syntax: + +```toml +[my_drone] +cmd_f_coef = 0.983 # from trans_params["cmd_f_coef"] +thrust_time_coef = 0.121 # from trans_params["thrust_time_coef"] +drag_matrix = [[-0.0147, 0.0, 0.0], + [0.0, -0.0147, 0.0], + [0.0, 0.0, -0.0128]] # diag([drag_xy, drag_xy, drag_z]) +rpy_coef = [-245.67, -245.67, -227.78] # from rot_params["rpy_coef"] +rpy_rates_coef = [-17.32, -17.32, -25.63] # from rot_params["rpy_rates_coef"] +cmd_rpy_coef = [196.18, 196.18, 390.27] # from rot_params["cmd_rpy_coef"] +``` + +!!! note + `sys_id_translation` returns `drag_xy_coef` and `drag_z_coef` as scalars. Assemble the diagonal `drag_matrix` manually: `[drag_xy, drag_xy, drag_z]` on the diagonal. + +Once the entry is in the TOML file, load the dynamics as usual: + +```{ .python notest } +from crazyflow.dynamics import parametrize +from crazyflow.dynamics.so_rpy_rotor_drag import dynamics + +dynamics = parametrize(dynamics, drone="my_drone") +``` + +Support for new drones can be added to the shared parameter files via a pull request on [GitHub](https://github.com/learnsyslab/crazyflow). + +## Which dynamics to identify + +Choose based on which physical effects you need to capture: + +- **`so_rpy`** — identifies only `cmd_f_coef`; no motor dynamics, no drag. Fastest to calibrate, good for slow flight. +- **`so_rpy_rotor`** — adds `thrust_time_coef` to model motor spin-up delay. Better for agile maneuvers. +- **`so_rpy_rotor_drag`** — adds `drag_xy_coef` and `drag_z_coef`. Best accuracy at higher speeds where aerodynamic drag is significant. + +--- diff --git a/docs/user-guide/gymnasium-envs.md b/docs/user-guide/gymnasium-envs.md index 35e7aa0..19b8cdb 100644 --- a/docs/user-guide/gymnasium-envs.md +++ b/docs/user-guide/gymnasium-envs.md @@ -38,9 +38,9 @@ All environments accept these common arguments: |---|---|---| | `num_envs` | 1 | Number of parallel environments | | `max_episode_time` | 10.0 | Episode length before truncation, seconds | -| `physics` | `Physics.so_rpy` | Physics model | -| `drone_model` | `"cf2x_L250"` | Drone configuration | -| `freq` | 500 | Physics frequency, Hz | +| `dynamics` | `Dynamics.so_rpy` | Dynamics | +| `drone` | `"cf2x_L250"` | Drone configuration | +| `freq` | 500 | Dynamics frequency, Hz | | `device` | `"cpu"` | `"cpu"` or `"gpu"` | | `reset_randomization` | `None` | Optional `(SimData, SimData, mask) → SimData` function applied at reset (base `DroneEnv` only) | diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md index bd7ed22..21663c9 100644 --- a/docs/user-guide/index.md +++ b/docs/user-guide/index.md @@ -5,8 +5,8 @@ In-depth documentation for every part of the simulator. - [Simulator Overview](sim-overview.md) — `SimData` layout, worlds, drones, and the data convention - [Object-Oriented API](oo-api.md) — `Sim` class, control methods, rendering, and reset - [Functional API](functional-api.md) — purely functional interface for JAX transformations -- [Physics Models](physics-models.md) — first-principles vs. fitted models, when to use each -- [Control Modes](control-modes.md) — state, attitude, force/torque, and rotor velocity control +- [Dynamics](dynamics/index.md) — first-principles vs. fitted dynamics, when to use each +- [Control Modes](control/index.md) — state, attitude, force/torque, and rotor velocity control - [Pipelines](pipelines.md) — composable step and reset pipelines, randomization, and disturbances - [Visualization](visualization.md) — rendering modes, cameras, raycasting, and materials - [MuJoCo Integration](mujoco.md) — MJCF scene construction, adding objects, and sync internals diff --git a/docs/user-guide/mujoco.md b/docs/user-guide/mujoco.md index 8cbdc22..9fff624 100644 --- a/docs/user-guide/mujoco.md +++ b/docs/user-guide/mujoco.md @@ -1,6 +1,6 @@ # MuJoCo Integration -Crazyflow focuses on drone physics and controllers. However, we still want to provide rendering and collision checking, and to do that we leverage [MuJoCo](https://mujoco.org/) and its JAX port [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html). We keep an MJX representation of the scene in sync with Crazyflow's physics state and invoke MJX functions where needed: collision queries, forward kinematics, and sensor rendering. GUI rendering uses the CPU-side MuJoCo renderer directly. +Crazyflow focuses on drone dynamics and controllers. However, we still want to provide rendering and collision checking, and to do that we leverage [MuJoCo](https://mujoco.org/) and its JAX port [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html). We keep an MJX representation of the scene in sync with Crazyflow's dynamics state and invoke MJX functions where needed: collision queries, forward kinematics, and sensor rendering. GUI rendering uses the CPU-side MuJoCo renderer directly. ## MuJoCo and MJX objects @@ -13,25 +13,37 @@ Crazyflow maintains two parallel representations at all times: | `sim.mjx_model` | `mjx.Model` | JAX pytree of the model (static, shared across worlds) | | `sim.mjx_data` | `mjx.Data` | JAX pytree of the scene state, batched over `n_worlds` | -`mjx_data` does not hold the physics state. It holds the scene geometry state (body transforms, contact distances, camera positions), derived from `sim.data` through an explicit sync step whenever rendering or collision queries are needed. +`mjx_data` does not hold the dynamics state. It holds the scene geometry state (body transforms, contact distances, camera positions), derived from `sim.data` through an explicit sync step whenever rendering or collision queries are needed. ## MJCF and scene construction The scene is built programmatically from MJCF (MuJoCo's XML format) at `Sim` construction time using the `MjSpec` API. The process is: 1. Load the base scene from `crazyflow/scene.xml` (floor, lighting, and sky). -2. Load the drone MJCF from the `drone-models` package. +2. Load the drone MJCF bundled with `crazyflow.drones` (under `crazyflow/drones`). 3. Mark the drone body as mocap. Mocap bodies are kinematically driven by external position and quaternion updates rather than joints, which avoids the O(nv²) cost of computing constraint matrices and saves memory. 4. Attach one copy per drone to a frame in the world body. 5. Compile the spec into `mj_model`, then convert to `mjx_model` and `mjx_data` via `mjx.put_model` and `mjx.put_data`. Vmap `mjx_data` across `n_worlds`. The spec is accessible as `sim.spec` before compilation, and `sim.mj_model` / `sim.mjx_model` after. +## Fused drone model (`fused_mjx_model`) + +Each drone MJCF defines two bodies. `drone` has separate visual meshes for the PCB, motors, propellers, LEDs, and battery. `drone_fused` bakes the visuals into a single mesh. Passing `fused_mjx_model=True` to `Sim` selects the fused body: + +```{ .python } +from crazyflow.sim import Sim + +sim = Sim(n_worlds=1, n_drones=1, fused_mjx_model=True) +``` + +The difference between the two is purely visual. The fused body consists of a single geom with one mesh, which shrinks its memory footprint. The cost is visual detail. Use it for large swarms or headless runs, and keep the default for detailed rendering. + ## Adding objects to the scene Custom geometry (gates, obstacles, walls, or any MJCF body) can be added by editing `sim.spec` and calling `sim.build_mjx()`. The new geometry is available for collision and rendering but has no effect on the drone dynamics, which are computed independently in JAX. -```{ .python notest } +```{ .python } import mujoco from crazyflow.sim import Sim @@ -76,7 +88,7 @@ If you mark an attached body as mocap (`attached.mocap = True`), its position ca ## Synchronization -The JAX physics pipeline writes to `sim.data` but never touches `sim.mjx_data`. `mjx_data` is only needed for collision queries and rendering, which require current body transforms. To avoid computing those on every physics step, Crazyflow tracks a `mjx_synced` flag in `sim.data.core`. +The JAX dynamics pipeline writes to `sim.data` but never touches `sim.mjx_data`. `mjx_data` is only needed for collision queries and rendering, which require current body transforms. To avoid computing those on every dynamics step, Crazyflow tracks a `mjx_synced` flag in `sim.data.core`. After `sim.step()` or `sim.reset()`, `mjx_synced` is set to `False`. The `sim.render()` and `sim.contacts()` methods check the flag; if stale, they call `sync_sim2mjx()` once and set it back to `True`. @@ -86,33 +98,33 @@ After `sim.step()` or `sim.reset()`, `mjx_synced` is set to `False`. The `sim.re 2. `jax.vmap(mjx.kinematics)` to propagate body transforms through the kinematic tree. 3. `jax.vmap(mjx.camlight)` and `jax.vmap(mjx.collision)` for rendering and contact detection respectively. -These run only once per render or contact call, regardless of how many physics steps were taken since the last sync. +These run only once per render or contact call, regardless of how many dynamics steps were taken since the last sync. -```{ .python notest } +```{ .python continuation } for i in range(10): - sim.step(5) # JAX physics only, mjx_synced = False + sim.step(5) # JAX dynamics only, mjx_synced = False if i % 5 == 0: - sim.render() # syncs once: kinematics + camlight + collision + sim.render(mode="rgb_array") # syncs once: kinematics + camlight + collision ``` ## Advanced: the sync flag and avoiding redundant MJX calls -`sync_sim2mjx` runs kinematics, collision detection, and camera transforms in one shot. The `mjx_synced` flag ensures this happens at most once between physics steps: once the flag is set, any further calls to `sim.render()` or `sim.contacts()` within the same tick skip the sync entirely and operate on the already-computed MJX state. The flag is only cleared when `sim.data` actually changes, so if the physics state has not advanced, the expensive MJX operations are not repeated. +`sync_sim2mjx` runs kinematics, collision detection, and camera transforms in one shot. The `mjx_synced` flag ensures this happens at most once between dynamics steps: once the flag is set, any further calls to `sim.render()` or `sim.contacts()` within the same tick skip the sync entirely and operate on the already-computed MJX state. The flag is only cleared when `sim.data` actually changes, so if the dynamics state has not advanced, the expensive MJX operations are not repeated. This means the order of calls matters. Grouping all rendering and contact queries together after a step lets them share a single sync: -```{ .python notest } +```{ .python continuation } sim.step(5) -contacts = sim.contacts() # sync runs here -sim.render() # flag already set, no second sync +contacts = sim.contacts() # sync runs here +sim.render(mode="rgb_array") # flag already set, no second sync ``` Interleaving a step between them forces two syncs: -```{ .python notest } -contacts = sim.contacts() # sync runs here -sim.step(5) # flag cleared -sim.render() # sync runs again +```{ .python continuation } +contacts = sim.contacts() # sync runs here +sim.step(5) # flag cleared +sim.render(mode="rgb_array") # sync runs again ``` ## Advanced: fusing mjx_data into a contact check function @@ -123,8 +135,11 @@ The solution is to **close over** `mjx_data` rather than pass it as an argument. The drone racing environment in [lsy_drone_racing](https://github.com/learnsyslab/lsy_drone_racing) uses this pattern to build a contact check function: -```{ .python notest } +```{ .python continuation } +from jax import Array + from crazyflow.sim.sim import sync_sim2mjx +from crazyflow.sim.data import SimData _mjx_data = sim.mjx_data # captured in closure diff --git a/docs/user-guide/oo-api.md b/docs/user-guide/oo-api.md index 8bf8e5a..b703a7e 100644 --- a/docs/user-guide/oo-api.md +++ b/docs/user-guide/oo-api.md @@ -10,18 +10,18 @@ The `Sim` class is the main entry point. It provides a Python-level control loop All configuration is fixed at construction time. ```python -from crazyflow.sim import Sim, Physics +from crazyflow.sim import Sim, Dynamics from crazyflow.sim.integration import Integrator from crazyflow.control import Control sim = Sim( n_worlds=1, n_drones=1, - drone_model="cf2x_L250", # Crazyflie 2.x with L250 props - physics=Physics.first_principles, + drone="cf2x_L250", # Crazyflie 2.x with L250 props + dynamics=Dynamics.first_principles, control=Control.state, integrator=Integrator.rk4, - freq=500, # physics update rate, Hz + freq=500, # dynamics update rate, Hz state_freq=100, # state controller rate, Hz attitude_freq=500, # attitude controller rate, Hz device="cpu", @@ -35,11 +35,11 @@ Key constructor arguments: |---|---|---| | `n_worlds` | 1 | Number of independent parallel environments | | `n_drones` | 1 | Drones per world | -| `drone_model` | `"cf2x_L250"` | Drone configuration (see `drone_models.available_drones`) | -| `physics` | `Physics.default` | Physics model | +| `drone` | `"cf2x_L250"` | Drone configuration (see `crazyflow.available_drones`) | +| `dynamics` | `Dynamics.default` | Dynamics | | `control` | `Control.default` | Control mode | | `integrator` | `Integrator.default` | Numerical integrator | -| `freq` | 500 | Physics frequency, Hz | +| `freq` | 500 | Dynamics frequency, Hz | | `device` | `"cpu"` | `"cpu"` or `"gpu"` | ## Control methods @@ -72,10 +72,10 @@ Commands roll, pitch, yaw setpoints (rad) and a collective thrust (N). This leve ```python import numpy as np -from crazyflow.sim import Sim, Physics +from crazyflow.sim import Sim, Dynamics from crazyflow.control import Control -sim = Sim(n_worlds=1, n_drones=1, control=Control.attitude, physics=Physics.so_rpy) +sim = Sim(n_worlds=1, n_drones=1, control=Control.attitude, dynamics=Dynamics.so_rpy) sim.reset() # [roll, pitch, yaw, collective_thrust_N] @@ -88,14 +88,14 @@ sim.step(sim.freq // sim.control_freq) ### Force-torque control -Direct force and torque input, useful for testing dynamics or custom controllers. Requires `Physics.first_principles`. +Direct force and torque input, useful for testing dynamics or custom controllers. Requires `Dynamics.first_principles`. ```python import numpy as np -from crazyflow.sim import Sim, Physics +from crazyflow.sim import Sim, Dynamics from crazyflow.control import Control -sim = Sim(n_worlds=1, n_drones=1, control=Control.force_torque, physics=Physics.first_principles) +sim = Sim(n_worlds=1, n_drones=1, control=Control.force_torque, dynamics=Dynamics.first_principles) sim.reset() # [collective_force_N, torque_x_Nm, torque_y_Nm, torque_z_Nm] @@ -108,14 +108,14 @@ sim.step(1) ### Rotor velocity control -The lowest level: directly command each motor's RPM. Requires `Physics.first_principles`. +The lowest level: directly command each motor's RPM. Requires `Dynamics.first_principles`. ```python import numpy as np -from crazyflow.sim import Sim, Physics +from crazyflow.sim import Sim, Dynamics from crazyflow.control import Control -sim = Sim(n_worlds=1, n_drones=1, control=Control.rotor_vel, physics=Physics.first_principles) +sim = Sim(n_worlds=1, n_drones=1, control=Control.rotor_vel, dynamics=Dynamics.first_principles) sim.reset() # [rpm_motor_0, rpm_motor_1, rpm_motor_2, rpm_motor_3] @@ -127,7 +127,7 @@ sim.step(1) ## Stepping and resetting -`sim.step(n_steps)` advances the simulation by `n_steps` physics ticks. On each tick, the full step pipeline runs, including the control stack. Controllers fire at their configured rate (e.g. the state controller at `state_freq`, the attitude controller at `attitude_freq`), not on every physics tick. Between controller ticks, the previously staged command is held. +`sim.step(n_steps)` advances the simulation by `n_steps` dynamics ticks. On each tick, the full step pipeline runs, including the control stack. Controllers fire at their configured rate (e.g. the state controller at `state_freq`, the attitude controller at `attitude_freq`), not on every dynamics tick. Between controller ticks, the previously staged command is held. Passing more steps to a single `step(n_steps)` call is more efficient than multiple `step(1)` calls: XLA compiles the full loop into a single kernel. If you have staged a control command and do not need to set a new one, you can advance the simulation by any number of steps and the controllers will continue firing at the correct rate. @@ -144,7 +144,7 @@ from crazyflow.control import Control sim = Sim(n_worlds=4, n_drones=1, control=Control.state) sim.reset() # reset all worlds -# Stage a command and advance 50 physics steps (controllers fire at their rate) +# Stage a command and advance 50 dynamics steps (controllers fire at their rate) cmd = np.zeros((4, 1, 13), dtype=np.float32) cmd[..., 2] = 0.5 sim.state_control(cmd) diff --git a/docs/user-guide/physics-models.md b/docs/user-guide/physics-models.md deleted file mode 100644 index 754a902..0000000 --- a/docs/user-guide/physics-models.md +++ /dev/null @@ -1,88 +0,0 @@ -# Physics Models - -Crazyflow supports four physics models, selectable via the `Physics` enum. All models share the same state representation and control interface, so you can swap them at construction time without changing any other code. - -```python -from crazyflow.sim import Sim, Physics - -sim = Sim(physics=Physics.first_principles) -``` - -## Available models - -| Model | Enum value | Command input | Description | -|---|---|---|---| -| First principles | `Physics.first_principles` | Rotor RPM | Full analytical model with identified parameters | -| SO(3) + RPY | `Physics.so_rpy` | Roll/pitch/yaw + thrust | Simplified fitted model | -| SO(3) + RPY + rotor | `Physics.so_rpy_rotor` | Roll/pitch/yaw + thrust | Adds first-order rotor dynamics | -| SO(3) + RPY + rotor + drag | `Physics.so_rpy_rotor_drag` | Roll/pitch/yaw + thrust | Adds translational and rotational drag | - -`Physics.default` resolves to `Physics.first_principles`. - -## First-principles model - -The first-principles model derives forces and torques analytically from motor speeds using identified physical parameters: mass, arm length, propeller constants, and the full inertia tensor. It operates at the rotor-velocity level and is the most accurate model for sim-to-real transfer. - -```python -from crazyflow.sim import Sim, Physics -from crazyflow.control import Control - -# Force-torque and rotor_vel control modes require first_principles -sim = Sim( - physics=Physics.first_principles, - control=Control.rotor_vel, -) -sim.reset() -``` - -Parameters accessible through `sim.data.params`: - -| Parameter | Description | -|---|---| -| `mass` | Drone mass, kg | -| `J` | Inertia matrix, kg·m² | -| `L` | Motor arm length, m | -| `rpm2thrust` | Thrust coefficient, N/(RPM²) | -| `rpm2torque` | Torque coefficient, Nm/(RPM²) | -| `mixing_matrix` | Maps rotor RPMs² to [thrust, tx, ty, tz] | -| `rotor_dyn_coef` | First-order rotor time constant | - -## Fitted models (so_rpy family) - -The `so_rpy` models are identified from flight data using a small number of flight minutes. They take higher-level commands (roll/pitch/yaw setpoints + collective thrust in Newtons) and are faster to simulate because they skip the rotor-velocity level. - -These models are a good choice when: - -- You are training RL agents and want speed over fidelity -- Your controller outputs attitude targets (as most Crazyflie firmware does) -- You do not need rotor-level detail - -```python -from crazyflow.sim import Sim, Physics -from crazyflow.control import Control - -sim = Sim( - physics=Physics.so_rpy_rotor_drag, # most accurate of the fitted family - control=Control.attitude, -) -sim.reset() -``` - -The `so_rpy_rotor_drag` variant includes translational drag, which captures the velocity-dependent deceleration effect visible in aggressive flights. It is the recommended fitted model for sim-to-real experiments. - -## Control mode compatibility - -| Physics model | `Control.state` | `Control.attitude` | `Control.force_torque` | `Control.rotor_vel` | -|---|---|---|---|---| -| `first_principles` | ✓ | ✓ | ✓ | ✓ | -| `so_rpy` | ✓ | ✓ | ✗ | ✗ | -| `so_rpy_rotor` | ✓ | ✓ | ✗ | ✗ | -| `so_rpy_rotor_drag` | ✓ | ✓ | ✗ | ✗ | - -!!! warning - Using `Control.force_torque` or `Control.rotor_vel` with a fitted model raises `ConfigError` at construction time. - -## Next steps - -- [Control Modes](control-modes.md) — command shapes and the control hierarchy -- [Object-Oriented API](oo-api.md) — full constructor arguments diff --git a/docs/user-guide/pipelines.md b/docs/user-guide/pipelines.md index fde3dbe..1e07fbd 100644 --- a/docs/user-guide/pipelines.md +++ b/docs/user-guide/pipelines.md @@ -22,7 +22,7 @@ Both pipelines are constructed at `Sim` initialisation and compiled into a singl `sim.step_pipeline` contains four stages by default: 1. **Control functions** — convert the staged command through the control hierarchy (state → attitude → force/torque → rotor velocities, depending on the selected mode) -2. **Integrator** (`integration`) — advance the ODE one physics step (Euler, RK4, or symplectic Euler) +2. **Integrator** (`integration`) — advance the ODE one dynamics step (Euler, RK4, or symplectic Euler) 3. **Step counter** (`increment_steps`) — increment `data.core.steps` 4. **Floor clip** (`clip_floor_pos`) — prevent drones from passing through the floor diff --git a/docs/user-guide/sim-overview.md b/docs/user-guide/sim-overview.md index 54cf3d1..e36af8e 100644 --- a/docs/user-guide/sim-overview.md +++ b/docs/user-guide/sim-overview.md @@ -4,8 +4,8 @@ Crazyflow organises simulation state into a two-dimensional batch: **worlds × drones**. -- **`n_worlds`** — number of independent simulation environments. Each world has its own physics state and evolves independently. Use this to run domain randomisation, parallel RL rollouts, or MPPI sampling. -- **`n_drones`** — number of drones per world. All drones in a world share the same physics tick but have independent states. +- **`n_worlds`** — number of independent simulation environments. Each world has its own dynamics state and evolves independently. Use this to run domain randomisation, parallel RL rollouts, or MPPI sampling. +- **`n_drones`** — number of drones per world. All drones in a world share the same dynamics tick but have independent states. Every state array has shape `(n_worlds, n_drones, feature_dim)`. To read the position of drone 0 in world 2: @@ -25,7 +25,7 @@ All simulation state is stored in `sim.data`, a `SimData` pytree. The main sub-t | Field | Type | Description | |---|---|---| | `states` | `SimState` | Current kinematic state of every drone | -| `states_deriv` | `SimStateDeriv` | Time derivatives computed by the physics model | +| `states_deriv` | `SimStateDeriv` | Time derivatives computed by the dynamics | | `controls` | `SimControls` | Staged commands and controller state | | `params` | `SimParams` | Physical parameters (mass, inertia, motor constants, …) | | `core` | `SimCore` | Metadata: step count, frequency, RNG key, device | @@ -63,9 +63,9 @@ sim.data = sim.data.replace(states=sim.data.states.replace(pos=new_pos)) ## Simulation frequency and the control stack -`freq` sets the physics update rate in Hz. The control stack is executed as part of each physics step, but controllers fire at their own sub-frequency rather than every tick. For example, with `freq=500` and `state_freq=100`, the state (Mellinger) controller runs every 5 physics steps, and the attitude controller runs at `attitude_freq`. +`freq` sets the dynamics update rate in Hz. The control stack is executed as part of each dynamics step, but controllers fire at their own sub-frequency rather than every tick. For example, with `freq=500` and `state_freq=100`, the state (Mellinger) controller runs every 5 dynamics steps, and the attitude controller runs at `attitude_freq`. -This means you can advance multiple physics steps in a single `sim.step(n_steps)` call and the control stack will execute at the correct rate automatically, with no manual sub-stepping required. This is also what makes fusing many steps into a single compiled call efficient. +This means you can advance multiple dynamics steps in a single `sim.step(n_steps)` call and the control stack will execute at the correct rate automatically, with no manual sub-stepping required. This is also what makes fusing many steps into a single compiled call efficient. ```python import numpy as np @@ -76,7 +76,7 @@ sim = Sim(freq=500, control=Control.state) sim.reset() cmd = np.zeros((1, 1, 13), dtype=np.float32) sim.state_control(cmd) -sim.step(sim.freq // sim.control_freq) # 500 // 100 = 5 physics steps, controller fires once +sim.step(sim.freq // sim.control_freq) # 500 // 100 = 5 dynamics steps, controller fires once ``` ## The step and reset pipelines diff --git a/docs/user-guide/visualization.md b/docs/user-guide/visualization.md index cdcfe90..9daec03 100644 --- a/docs/user-guide/visualization.md +++ b/docs/user-guide/visualization.md @@ -1,6 +1,6 @@ # Visualization -Crazyflow supports onscreen interactive rendering and offscreen RGB/depth capture through MuJoCo's renderer. Rendering is fully decoupled from the physics step: call `sim.render()` at any frequency independently of how fast the simulation runs. +Crazyflow supports onscreen interactive rendering and offscreen RGB/depth capture through MuJoCo's renderer. Rendering is fully decoupled from the dynamics step: call `sim.render()` at any frequency independently of how fast the simulation runs. @@ -84,7 +84,7 @@ sim.render(world=3) # render world 3 ## Sync and performance -Rendering triggers an implicit synchronization between the JAX physics state (`sim.data`) and the MuJoCo render buffers (`sim.mjx_data`). This sync computes full forward kinematics, camera transforms, and collision geometry — it is the most expensive operation per render call. See [MuJoCo Integration](mujoco.md#synchronization) for details on how to avoid redundant syncs. +Rendering triggers an implicit synchronization between the JAX dynamics state (`sim.data`) and the MuJoCo render buffers (`sim.mjx_data`). This sync computes full forward kinematics, camera transforms, and collision geometry — it is the most expensive operation per render call. See [MuJoCo Integration](mujoco.md#synchronization) for details on how to avoid redundant syncs. ## Next steps diff --git a/examples/contacts/contacts.py b/examples/contacts/contacts.py index 8771cf6..b0992fc 100644 --- a/examples/contacts/contacts.py +++ b/examples/contacts/contacts.py @@ -1,13 +1,13 @@ import numpy as np -from crazyflow.sim import Physics, Sim +from crazyflow.sim import Dynamics, Sim from crazyflow.sim.sim import use_box_collision def main(): """Spawn multiple drones in multiple worlds and check for contacts.""" n_worlds, n_drones = 2, 3 - sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=Physics.so_rpy, device="cpu") + sim = Sim(n_worlds=n_worlds, n_drones=n_drones, dynamics=Dynamics.so_rpy, device="cpu") use_box_collision(sim, enable=True) # Enable box collision for all drones fps = 60 diff --git a/examples/control/hover.py b/examples/control/hover.py index 36241d5..c597e59 100644 --- a/examples/control/hover.py +++ b/examples/control/hover.py @@ -1,14 +1,14 @@ import numpy as np from crazyflow.control import Control -from crazyflow.sim import Physics, Sim +from crazyflow.sim import Dynamics, Sim def main(): sim = Sim( n_worlds=1, n_drones=1, - physics=Physics.first_principles, + dynamics=Dynamics.first_principles, control=Control.state, freq=500, attitude_freq=500, diff --git a/examples/control/sampling.py b/examples/control/sampling.py index 3fde3ea..f0a83d8 100644 --- a/examples/control/sampling.py +++ b/examples/control/sampling.py @@ -13,13 +13,13 @@ import jax import jax.numpy as jnp import numpy as np -from drone_models.core import load_params -from drone_models.transform import motor_force2rotor_vel from jax import Array from jax.lax import scan from crazyflow.control import Control -from crazyflow.sim import Physics, Sim +from crazyflow.control.transform import motor_force2rotor_vel +from crazyflow.drones import load_params +from crazyflow.sim import Dynamics, Sim from crazyflow.sim.data import SimData from crazyflow.sim.visualize import draw_capsule, draw_line @@ -31,7 +31,7 @@ DEVICE_CONTROLLER = "cpu" # Simulation configuration -DRONE_MODEL = "cf21B_500" +DRONE = "cf21B_500" DURATION = 10.0 FPS = 60 RENDER = True @@ -142,7 +142,7 @@ def rollout_sim( quat=rollout_data.states.quat.at[...].set(obs["quat"]), vel=rollout_data.states.vel.at[...].set(obs["vel"]), ang_vel=rollout_data.states.ang_vel.at[...].set(obs["ang_vel"]), - # The reduced model stores collective thrust in its rotor_vel state. + # The reduced dynamics store collective thrust in its rotor_vel state. rotor_vel=rollout_data.states.rotor_vel.at[...].set(obs["collective_thrust"]), ) data = rollout_data.replace(states=states) @@ -222,16 +222,11 @@ def main() -> None: obstacles = obstacle_grid() # Set up the main sim - sim = Sim( - n_worlds=1, - drone_model=DRONE_MODEL, - physics=Physics.first_principles, - control=Control.attitude, - ) + sim = Sim(n_worlds=1, drone=DRONE, dynamics=Dynamics.first_principles, control=Control.attitude) sim.max_visual_geom = 100_000 # To be able to show all rollouts sim.reset() start_pos = lissajous_reference(0.0)["pos"] - drone_params = load_params("first_principles", DRONE_MODEL) + drone_params = load_params(DRONE) hover_thrust_value = np.asarray(drone_params["mass"] * 9.81, dtype=np.float32) hover_rotor_vel = motor_force2rotor_vel( np.full(4, hover_thrust_value / 4.0, dtype=np.float32), drone_params["rpm2thrust"] @@ -249,8 +244,8 @@ def main() -> None: rollout_simulator = Sim( n_worlds=N_SAMPLES, device=controller_device.platform, - drone_model=DRONE_MODEL, - physics=Physics.so_rpy_rotor_drag, + drone=DRONE, + dynamics=Dynamics.so_rpy_rotor_drag, control=Control.attitude, freq=rollout_freq, attitude_freq=rollout_freq, @@ -303,7 +298,7 @@ def main() -> None: "quat": sim.data.states.quat[0, 0], "vel": sim.data.states.vel[0, 0], "ang_vel": sim.data.states.ang_vel[0, 0], - # Thrust is not observable and difficult to estimate, so use the thrust model. + # Thrust is not observable and difficult to estimate, so use the thrust dynamics. "collective_thrust": thrust_estimate, } action, key, mean_controls, best_positions, sampled_positions = control( diff --git a/examples/control/spiral.py b/examples/control/spiral.py index 598a9cf..8f7b556 100644 --- a/examples/control/spiral.py +++ b/examples/control/spiral.py @@ -13,7 +13,7 @@ def control(start_xy: np.ndarray, t: float) -> np.ndarray: def main(): - sim = Sim(n_drones=4, control=Control.state, integrator="rk4", physics="first_principles") + sim = Sim(n_drones=4, control=Control.state, integrator="rk4", dynamics="first_principles") sim.reset() duration = 5.0 fps = 60 diff --git a/examples/jax/cache.py b/examples/jax/cache.py index 529347b..3d5ab66 100644 --- a/examples/jax/cache.py +++ b/examples/jax/cache.py @@ -5,7 +5,7 @@ However, the cache is not persistent between Python sessions. The Sim class uses many jitted functions internally, particularly in the step() method which -compiles a chain of physics and control functions. On the first step() call, the entire chain is +compiles a chain of dynamics and control functions. On the first step() call, the entire chain is compiled. After the Python session ends, the cached functions get lost. However, we can enable a persistent diff --git a/examples/jax/gradient.py b/examples/jax/gradient.py index 14b1b1e..5f735be 100644 --- a/examples/jax/gradient.py +++ b/examples/jax/gradient.py @@ -5,12 +5,12 @@ from numpy.typing import NDArray from crazyflow.control import Control -from crazyflow.sim import Physics, Sim +from crazyflow.sim import Dynamics, Sim from crazyflow.sim.data import SimData def main(): - sim = Sim(control=Control.attitude, physics=Physics.first_principles, attitude_freq=50) + sim = Sim(control=Control.attitude, dynamics=Dynamics.first_principles, attitude_freq=50) # Remove clipping floor function which kills gradients sim_step = sim.build_step_fn() # If the drone starts on the floor, the gradient gets killed by the floor clipping function. We diff --git a/examples/plugins/estimation.py b/examples/plugins/estimation.py index 088b1be..1686129 100644 --- a/examples/plugins/estimation.py +++ b/examples/plugins/estimation.py @@ -10,9 +10,9 @@ import jax import jax.numpy as jnp import numpy as np -from drone_models.transform import motor_force2rotor_vel from crazyflow import Sim +from crazyflow.control.transform import motor_force2rotor_vel from crazyflow.sim.pipeline import insert_fn_after, prepend_fn from crazyflow.sim.visualize import draw_line, draw_points @@ -170,7 +170,7 @@ def main(noisy: bool = False, render: bool = True) -> None: prepend_fn(sim.step_pipeline, simulate_uwb) insert_fn_after(sim.step_pipeline, "simulate_uwb", estimate_state) insert_fn_after(sim.step_pipeline, "estimate_state", use_estimate_for_control) - insert_fn_after(sim.step_pipeline, "step_force_torque_controller", restore_ground_truth) + insert_fn_after(sim.step_pipeline, "force_torque_controller", restore_ground_truth) sim.build_default_data() sim.build_step_fn() diff --git a/examples/rendering/cameras.py b/examples/rendering/cameras.py index ef68274..c07b59e 100644 --- a/examples/rendering/cameras.py +++ b/examples/rendering/cameras.py @@ -8,9 +8,9 @@ from matplotlib import animation from crazyflow.control import Control +from crazyflow.dynamics import Dynamics from crazyflow.sim import Sim from crazyflow.sim.integration import Integrator -from crazyflow.sim.physics import Physics def control(t: float, t_tot: float) -> np.ndarray: @@ -65,8 +65,8 @@ def main(show_plot: bool = False, save_plot: bool = False): n_drones=1, control=Control.state, integrator=Integrator.rk4, - physics=Physics.first_principles, - drone_model="cf2x_T350", + dynamics=Dynamics.first_principles, + drone="cf2x_T350", ) add_smiley(sim) sim.reset() @@ -100,9 +100,9 @@ def update_frame(_): # noqa: ANN202 # mode: Either "human" for the regular window, "rgb_array" for an RGB array, # "depth_array" for a depth array, or "rgbd_tuple" for both at the same time. # camera: The name or id of the camera. The names are specified in the corresponding - # xml file in drone_models. For example, "fpv_cam:0" is the first-person view camera - # of the first drone, "track_cam:0" is the tracking camera of the first drone. - # Id -1 is the global camera. + # xml file in crazyflow/drones. For example, "fpv_cam:0" is the first-person view + # camera of the first drone, "track_cam:0" is the tracking camera of the first + # drone. Id -1 is the global camera. rgbd = sim.render( width=resolution[0], height=resolution[1], mode="rgbd_tuple", camera="fpv_cam:0" ) diff --git a/examples/rendering/led_deck.py b/examples/rendering/led_deck.py index aba4739..aa556a6 100644 --- a/examples/rendering/led_deck.py +++ b/examples/rendering/led_deck.py @@ -3,7 +3,7 @@ import numpy as np -from crazyflow.control.control import Control +from crazyflow.control import Control from crazyflow.sim import Sim from crazyflow.sim.visualize import change_material @@ -42,7 +42,7 @@ def main(): tmp.flush() tmp_path = Path(tmp.name) - sim = Sim(n_drones=25, drone_model="cf21B_500", control=Control.state, xml_path=tmp_path) + sim = Sim(n_drones=25, drone="cf21B_500", control=Control.state, xml_path=tmp_path) fps = 60 cmd = np.zeros((sim.n_worlds, sim.n_drones, 4)) cmd[..., 3] = sim.data.params.mass[0, 0, 0] * 9.81 diff --git a/examples/rendering/render.py b/examples/rendering/render.py index 45aae74..9b6055d 100644 --- a/examples/rendering/render.py +++ b/examples/rendering/render.py @@ -2,14 +2,14 @@ import numpy as np -from crazyflow.sim import Physics, Sim +from crazyflow.sim import Dynamics, Sim from crazyflow.sim.visualize import draw_line def main(): """Spawn 25 drones in one world and render each with a trace behind it.""" n_worlds, n_drones = 1, 25 - sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=Physics.so_rpy, device="cpu") + sim = Sim(n_worlds=n_worlds, n_drones=n_drones, dynamics=Dynamics.so_rpy, device="cpu") fps = 60 cmd = np.zeros((sim.n_worlds, sim.n_drones, 4)) cmd[..., 3] = sim.data.params.mass[0, 0, 0] * 9.81 * 1.05 diff --git a/examples/symbolic.py b/examples/symbolic.py index 1a2a006..c722282 100644 --- a/examples/symbolic.py +++ b/examples/symbolic.py @@ -1,21 +1,22 @@ import casadi as cs import numpy as np +from crazyflow.dynamics import parametrize +from crazyflow.dynamics.so_rpy import symbolic_dynamics from crazyflow.sim import Sim -from crazyflow.sim.symbolic import symbolic_from_sim def main(): - # We can create a symbolic model directly from the simulation. Note that this will use the + # We can create a symbolic dynamics directly from the simulation. Note that this will use the # nominal parameters of the simulation and choose the control type based on the simulation. - sim = Sim(physics="so_rpy", freq=500) - X_dot, X, U, Y = symbolic_from_sim(sim) + sim = Sim(dynamics="so_rpy", freq=500) + X_dot, X, U, Y = parametrize(symbolic_dynamics, sim.drone)() assert X_dot.shape == (13, 1) assert X.shape == (13, 1) assert U.shape == (4, 1) # Attitude control assert Y.shape == (7, 1) # 3 for pos and 4 for quat - # To create a discrete-time model that you can integrate, you can use the integrator function + # To create a discrete-time dynamics that you can integrate, you can use the integrator function # from CasADi. fd = cs.integrator("fd", "cvodes", {"x": X, "p": U, "ode": X_dot}, 0, 1 / sim.freq) x0 = np.ones((13, 1)) diff --git a/pixi.lock b/pixi.lock index ab2e8ae..3449b24 100644 --- a/pixi.lock +++ b/pixi.lock @@ -27,7 +27,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libbrotlienc-1.2.0-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.11.0-8_h0358290_openblas.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.25-h17f619e_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.5.2-h3435931_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.3-ha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.3-h73754d4_0.conda @@ -42,7 +42,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.1-hb9d3cd8_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.33-pthreads_h94d23a6_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.58-h421ea60_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.1-h0c1763c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.2-h0c1763c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_19.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-15.2.0-hdf11a46_19.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.1-h9d88235_1.conda @@ -55,13 +55,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.6-hdb14827_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.4.6-py312h33ff503_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.4-h55fea9a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.2-h35e630c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.3-h35e630c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pillow-12.2.0-py312h50c33e8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.12.13-hd63d673_0_cpython.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/qhull-2020.2-h434a139_5.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.15-h6a952e8_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.17-h6a952e8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h366c992_103.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/unicodedata2-17.0.1-py312h4c3975b_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxau-1.0.12-hb03c661_1.conda @@ -83,8 +83,6 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-hc9c84f9_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/wheel-0.47.0-pyhd8ed1ab_0.conda - pypi: ./ - - pypi: ./submodules/drone-controllers - - pypi: ./submodules/drone-models - pypi: https://files.pythonhosted.org/packages/01/8e/1e35281b8ab6d5d72ebe9911edcdffa3f36b04ed9d51dec6dd140396e220/scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/05/98/716a473cfb24750858ddd5d14e6527539dd206583a46408d08eeb2844a75/trimesh-4.12.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0c/b6/156a8de1e1b47694f0e7de6675866936608d45dc68388fd017d36f8693be/simplejson-4.1.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl @@ -92,6 +90,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/18/2a/d4cd8506d2044e082f8cd921be57392e6a9b5ccd3ffdf050362430a3d5d5/nvidia_cuda_cccl_cu12-12.9.27-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/18/7c/b7b24e10e5cb0213c85204d53fcd60d0568d986ea0001a00a815e14e01e1/tensorstore-0.1.84-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/24/ab/d7233c915b12c005655437c6c4cf0ae46cbbb2b20d743cb5e4881ad3104a/casadi-3.7.2-cp312-none-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/25/48/b54a06168a2190572a312bfe4ce443687773eb61367ced31e064953dd2f7/nvidia_cuda_nvcc_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl @@ -99,34 +98,33 @@ environments: - pypi: https://files.pythonhosted.org/packages/33/40/79b0c64d44d6c166c0964ec1d803d067f4a145cca23e23925fd351d0e642/nvidia_cusolver_cu12-11.7.5.82-py3-none-manylinux_2_27_x86_64.whl - pypi: https://files.pythonhosted.org/packages/33/d1/8bb87d21e9aeb323cc03034f5eaf2c8f69841e40e4853c2627edf8111ed3/termcolor-3.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/38/69/2912ab63036e21c72748019e1d8e09e8a1fc3368b3e83fc27898a1858575/jaxlib-0.10.1-cp312-cp312-manylinux_2_27_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/38/ed/b7728573156d70b6b094233b0f38d876fc37340826cf852347ec2c7ca8ca/msgpack-1.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/3a/13/547360d81e6d88d58492968ffda9f9542854f11310ee556fef14260cc886/zipp-4.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/3a/cb/28ce52eb94390dda42599c98ea0204d74799e4d8047a0eb559b6fd648056/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/3e/85/1c12e849e4d50624e75496378a3fb168389f768d3ec7cb694fba873ff9a8/nvidia_nvshmem_cu12-3.7.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/2b/36e984399089c026a6499ac8f7401d38487cf0183839a4aa78140d373771/treescope-0.1.10-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/46/0c/c75bbfb967457a0b7670b8ad267bfc4fffdf341c074e0a80db06c24ccfd4/nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/03/5b668e78eff52a459c707e442a3cbd3e0f8b74d08a4b92111a07159aff11/mujoco_mjx-3.9.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/50/32/e7ffa9c324ae260e5dbb4af2cd557bf7a8d155c8ac7b79a785fe1796fb92/nvidia_nccl_cu12-2.30.7-py3-none-manylinux_2_18_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3f/efeb7c6801c46e11bd666a5180f0d615f74f72264212f74f39586c6fda9d/glfw-2.10.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.py39.py310.py311.py312.py313.py314-none-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5c/6e/5087e0347188f6970aba1ffbd0018754d23c3f3461e9f21785f2f27a02c2/jax-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5f/6f/e62b4dfc7ad6518e7eff2516f680d02a0f6eb62c0c212e152ca708a0085e/uvloop-0.22.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/65/92/a5100f7185a800a5d29f8d14041f61475b9de465ffcc0f3b9fba606e4505/msgpack-1.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/65/b6/09b01cdbc15224e2850365192d17b7bdebb8bdbd8780ed221fcdf0d9a515/pandas-3.0.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/6b/c3/0e45ff4dce8401f6ea7c25d80d75738813a47f5ae2691e2478f2fd1e5e93/nvidia_nccl_cu12-2.30.4-py3-none-manylinux_2_18_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7c/f0/21f81892e4ed10f4ec3ef2e7cf8635fb76e7c0907c55d0da66be50094760/farama_notifications-0.0.6-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/7d/9d/1a383211b0967e702b9e84643986fb31bf35ca07bddc19e0cf139fd3291d/nvidia_cudnn_cu12-9.23.0.39-py3-none-manylinux_2_27_x86_64.whl - pypi: https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8a/3f/cabd3c791ff5042df157609e00e96440ccaba69f72bccd8e3470d85fdd48/jax_cuda12_plugin-0.10.1-cp312-cp312-manylinux_2_27_x86_64.whl - pypi: https://files.pythonhosted.org/packages/8a/69/6a93d8600c339d7687a05857c7907bd4dd8cf88691a5ea106d7a50af90a1/optax-0.2.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/65/4bd2abfd4cb6e917b2626de5cbfc034dfc94b74dd95b8272d93f2ad66bed/flax-0.12.7-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/91/4d/2ca3ca9906ce6e05070f431c54d54ccbaf57a980cfa58032d35b0b0ac1f8/pyinstrument-5.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/95/f4/61e6996dd20481ee834f57a8e9dca28b1869366a135e0d42e2aa8493bdd4/nvidia_cufft_cu12-11.4.1.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/9e/da/36fa8307cc40889307fed415d70b67d35ec330ffce889a9c03cf8f616cfa/nvidia_nvshmem_cu12-3.6.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/a0/7e/e0a5d44bf070a1ff945050abc02ef1cff5ca9c6ab5dc6a16ab6322593a32/nvidia_cudnn_cu12-9.23.1.3-py3-none-manylinux_2_27_x86_64.whl - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl @@ -140,6 +138,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/d5/01/b2a88b6b73df933d5ab38583240c296684b626a8de3c3bb9a7c2fd356f08/mujoco-3.9.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d5/0c/043d5e551459da400957a1395e0febbf771446ff34291afcbe3d8be2a279/fsspec-2026.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/de/e4/1ba6f44e491c4eece978685230dde56b14d51a0365bc1b774ddaa94d14cd/pyopengl-3.1.10-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e4/be/5b3cfe508bfab6761414ff944e3366eb13be4fd71efcd69450f89ba39f43/protobuf-7.35.1-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e5/4c/93d0f85318da65923e4b91c1c2ff03d8a458cbefebe3bc612a6693c7906d/fire-0.7.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/73/fda6a25f3beeb5e49d74330b44092b9e5a547395ccd478d1103ddcbff1fc/gymnasium-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f7/a1/47c08a81760cae84c4a4aa720f3fc1ce3bac6f7aafa5ab82c302d7946f07/jax_cuda12_pjrt-0.10.1-py3-none-manylinux_2_27_x86_64.whl @@ -158,7 +157,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.11.0-8_h4a7cf45_openblas.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.11.0-8_h0358290_openblas.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.25-h17f619e_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.5.2-h3435931_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.3-ha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.3-h73754d4_0.conda @@ -172,7 +171,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libmpdec-4.0.0-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.33-pthreads_h94d23a6_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.58-h421ea60_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.1-h0c1763c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.2-h0c1763c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_19.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.1-h9d88235_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.42.1-h5347b49_0.conda @@ -182,13 +181,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.6-hdb14827_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.4.6-py313hf6604e3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.4-h55fea9a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.2-h35e630c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.3-h35e630c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pillow-12.2.0-py313h80991f8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.13-h6add32d_100_cp313.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.14-h6add32d_100_cp313.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pyyaml-6.0.3-py313h3dea7bd_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.15-h6a952e8_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.17-h6a952e8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h366c992_103.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxau-1.0.12-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxdmcp-1.1.5-hb03c661_1.conda @@ -204,35 +203,31 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-hc9c84f9_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-4.1.0-pyhcf101f3_0.conda - pypi: ./ - - pypi: ./submodules/drone-controllers - - pypi: ./submodules/drone-models - pypi: https://files.pythonhosted.org/packages/05/98/716a473cfb24750858ddd5d14e6527539dd206583a46408d08eeb2844a75/trimesh-4.12.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/15/c0/0be24758891ef825f2065cd5db8741aaddabe3e248ee6acc5e8a80f04005/uvloop-0.22.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/1d/69a0ba52fb546261e71a7209378ee6059950e9c088a2a18355e01509f474/jaxlib-0.10.1-cp313-cp313-manylinux_2_27_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/38/2e/21a3ede87f0bf82d6c7bcb90480d50a6490eb974c6ab20881188e440957c/simplejson-4.1.1-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl - pypi: https://files.pythonhosted.org/packages/3f/5b/7120e22f6e22ca77283f4a086ab2e59d107f00bfc952116db41a015385fe/casadi-3.7.2-cp313-none-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/2b/36e984399089c026a6499ac8f7401d38487cf0183839a4aa78140d373771/treescope-0.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/03/5b668e78eff52a459c707e442a3cbd3e0f8b74d08a4b92111a07159aff11/mujoco_mjx-3.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3f/efeb7c6801c46e11bd666a5180f0d615f74f72264212f74f39586c6fda9d/glfw-2.10.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.py39.py310.py311.py312.py313.py314-none-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5c/6e/5087e0347188f6970aba1ffbd0018754d23c3f3461e9f21785f2f27a02c2/jax-0.10.1-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5e/c6/82669e70cef67c803852285ba6f59d7e3d102983c0ab4be8269c14756677/tensorstore-0.1.84-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7c/f0/21f81892e4ed10f4ec3ef2e7cf8635fb76e7c0907c55d0da66be50094760/farama_notifications-0.0.6-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8a/69/6a93d8600c339d7687a05857c7907bd4dd8cf88691a5ea106d7a50af90a1/optax-0.2.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/65/4bd2abfd4cb6e917b2626de5cbfc034dfc94b74dd95b8272d93f2ad66bed/flax-0.12.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/97/df/a1495de78c1da3e8e93978dd177b04d18aaa7361452e30a3467c41c3b19e/mujoco-3.9.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl @@ -241,9 +236,11 @@ environments: - pypi: https://files.pythonhosted.org/packages/c7/d1/63b5014a6184210292c66944f051e9fc95c0272fe5464d1b1a2de5de0104/orbax_checkpoint-0.12.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d5/0c/043d5e551459da400957a1395e0febbf771446ff34291afcbe3d8be2a279/fsspec-2026.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/de/e4/1ba6f44e491c4eece978685230dde56b14d51a0365bc1b774ddaa94d14cd/pyopengl-3.1.10-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e4/be/5b3cfe508bfab6761414ff944e3366eb13be4fd71efcd69450f89ba39f43/protobuf-7.35.1-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e9/73/fda6a25f3beeb5e49d74330b44092b9e5a547395ccd478d1103ddcbff1fc/gymnasium-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/eb/33/40cd74219417e78b97c47802037cf2d87b91973e18bb968a7da48a96ea44/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/f5/5f/f17563f28ff03c7b6799c50d01d5d856a1d55f2676f537ca8d28c7f627cd/scipy-1.17.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/fb/63/68f5d0ea81e167db5f59ddb94dc6f837667062113feff1c73fabf8907061/msgpack-1.2.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2026.5.20-hbd8a1cb_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/packaging-26.2-pyhc364b38_0.conda @@ -255,16 +252,17 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-4.1.0-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/_openmp_mutex-4.5-7_kmp_llvm.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/bzip2-1.0.8-hd037594_9.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/icu-78.3-hef89b57_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/lcms2-2.19.1-hdfa7624_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/lerc-4.1.0-h1eee2c3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libblas-3.11.0-8_h51639a9_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libcblas-3.11.0-8_hb0561ab_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libcxx-22.1.7-h55c6f16_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libdeflate-1.25-hc11a715_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libexpat-2.8.1-hf6b4638_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libexpat-2.8.1-hf6b4638_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libffi-3.5.2-hcf2aa1b_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype-2.14.3-hce30654_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype6-2.14.3-hdfa99f5_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype-2.14.3-hce30654_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype6-2.14.3-hdfa99f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgcc-15.2.0-hcbb3090_19.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgfortran-15.2.0-h07b0088_19.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgfortran5-15.2.0-hdae7583_19.conda @@ -274,7 +272,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libmpdec-4.0.0-h84a0fba_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libopenblas-0.3.33-openmp_he657e61_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libpng-1.6.58-h132b30e_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.53.1-h1b79a29_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.53.2-h1ae2325_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libtiff-4.7.1-h4030677_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libwebp-base-1.6.0-h07db88b_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libxcb-1.17.0-hdb1d25a_0.conda @@ -283,13 +281,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ncurses-6.6-h1d4f5a5_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-2.4.6-py313hce9b930_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openjpeg-2.5.4-hd9e9057_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.6.2-hd24854e_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.6.3-hd24854e_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pillow-12.2.0-py313h45e5a15_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pthread-stubs-0.4-hd74edd7_1002.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.13.13-h20e6be0_100_cp313.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.13.14-h448ec07_100_cp313.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pyyaml-6.0.3-py313h65a2061_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/readline-8.3-h46df422_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.15.15-h80928e0_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.15.17-h80928e0_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tk-8.6.13-h010d191_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxau-1.0.12-hc919400_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxdmcp-1.1.5-hc919400_1.conda @@ -297,17 +295,16 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: ./ - - pypi: ./submodules/drone-controllers - - pypi: ./submodules/drone-models - pypi: https://files.pythonhosted.org/packages/05/98/716a473cfb24750858ddd5d14e6527539dd206583a46408d08eeb2844a75/trimesh-4.12.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/09/b7/087fcbfe2a0a0b44e236c9853d7fa7c539db6b8c60ab5702fffd73be5a7c/casadi-3.7.2-cp313-none-macosx_11_0_arm64.whl + - pypi: https://files.pythonhosted.org/packages/10/03/8aeeb7458d22546bf64b5250ca1daeb5ff757d900e8e4a7476c6f0db843e/protobuf-7.35.1-cp310-abi3-macosx_10_9_universal2.whl - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/83/4f3c6ef9bed01f384036c2030b3901cf075bbc8eff6e4529e502f0283ab5/tensorstore-0.1.84-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/2b/36e984399089c026a6499ac8f7401d38487cf0183839a4aa78140d373771/treescope-0.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/03/5b668e78eff52a459c707e442a3cbd3e0f8b74d08a4b92111a07159aff11/mujoco_mjx-3.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl @@ -315,15 +312,13 @@ environments: - pypi: https://files.pythonhosted.org/packages/7c/f0/21f81892e4ed10f4ec3ef2e7cf8635fb76e7c0907c55d0da66be50094760/farama_notifications-0.0.6-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/83/ee/93d06e358a4aa32280b00e722d3ea0a1f25fc3cc5778d80581c9cca2c10e/protobuf-7.35.0-cp310-abi3-macosx_10_9_universal2.whl + - pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/89/8c/182a2a593195bfd39842ea68ebc084e20c850806117213f5a299dfc513d9/uvloop-0.22.1-cp313-cp313-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/8a/69/6a93d8600c339d7687a05857c7907bd4dd8cf88691a5ea106d7a50af90a1/optax-0.2.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/65/4bd2abfd4cb6e917b2626de5cbfc034dfc94b74dd95b8272d93f2ad66bed/flax-0.12.7-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/92/dc/c385f38f2c2433333345a82926c6bfa5ecfff3ef787201614317b58dd8be/msgpack-1.1.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/9a/e5/54cb7c50ad5fdc1e0a86b7df4b135c2cbd5c4623605aa94466659098e8da/simplejson-4.1.1-cp313-cp313-macosx_11_0_arm64.whl - - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b1/88/a29bca408a4c2db6c5bcf58a8b92c464660b7f846c559abd9110783574cb/mujoco-3.9.0-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl @@ -331,6 +326,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/bc/8a/340a1555ae33d7354dbca4faa54948d76d89a27ceef032c8c3bc661d003e/aiofiles-25.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c5/7b/bca5613a0c3b542420cf92bd5e5fb8ebd5435ce1011a091f66bb7693285e/humanize-4.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/d1/63b5014a6184210292c66944f051e9fc95c0272fe5464d1b1a2de5de0104/orbax_checkpoint-0.12.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/59/7e6b812629d2f919e586041bffc130e1af32079f71bb20699eed54ed6d92/msgpack-1.2.0-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/cf/76/3b637d4def229015a3035a7b44fac0dcf2536ae337540cdbffc651334d4e/jaxlib-0.10.1-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/d5/0c/043d5e551459da400957a1395e0febbf771446ff34291afcbe3d8be2a279/fsspec-2026.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d9/a1/4008f14bbc616cfb1ac5b39ea485f9c63031c4634ab3f4cf72e7541f816a/ml_dtypes-0.5.4-cp313-cp313-macosx_10_13_universal2.whl @@ -345,7 +341,7 @@ environments: packages: linux-64: - conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-20_gnu.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/backports.zstd-1.5.0-py313h18e8e13_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/backports.zstd-1.6.0-py313h18e8e13_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/brotli-python-1.2.0-py313hf159716_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hda65f42_9.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.19.1-h0c24ade_1.conda @@ -354,7 +350,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.11.0-8_h4a7cf45_openblas.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.11.0-8_h0358290_openblas.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.25-h17f619e_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.5.2-h3435931_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.3-ha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.3-h73754d4_0.conda @@ -368,7 +364,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libmpdec-4.0.0-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.33-pthreads_h94d23a6_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.58-h421ea60_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.1-h0c1763c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.2-h0c1763c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_19.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.1-h9d88235_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.42.1-h5347b49_0.conda @@ -378,13 +374,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.6-hdb14827_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.4.6-py313hf6604e3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.4-h55fea9a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.2-h35e630c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.3-h35e630c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pillow-12.2.0-py313h80991f8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.13-h6add32d_100_cp313.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.14-h6add32d_100_cp313.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pyyaml-6.0.3-py313h3dea7bd_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.15-h6a952e8_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.17-h6a952e8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h366c992_103.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxau-1.0.12-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxdmcp-1.1.5-hb03c661_1.conda @@ -409,8 +405,6 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/urllib3-2.7.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-4.1.0-pyhcf101f3_0.conda - pypi: ./ - - pypi: ./submodules/drone-controllers - - pypi: ./submodules/drone-models - pypi: https://files.pythonhosted.org/packages/05/98/716a473cfb24750858ddd5d14e6527539dd206583a46408d08eeb2844a75/trimesh-4.12.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0c/c3/44f3fbbfa403ea2a7c779186dc20772604442dde72947e7d01069cbe98e3/pycparser-3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0d/fe/6bea5c9162869c5beba5d9c8abbed835ec85bf1ec1fba05a3822325c45f3/build-1.5.0-py3-none-any.whl @@ -418,6 +412,8 @@ environments: - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/1b/0e/bf298920729f216adcb002acf7ea01b90842603d2e4e2ce9b900d9ee8fab/nh3-0.3.5-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/20/1d/69a0ba52fb546261e71a7209378ee6059950e9c088a2a18355e01509f474/jaxlib-0.10.1-cp313-cp313-manylinux_2_27_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/20/2c/0622f20ff02b2ef32558733443805dc82fd4c275be01b2d19d14676f3a1b/cryptography-49.0.0-cp311-abi3-manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/91/30151a39f7570f448ed84529390628a651d7f27c87d73c9b887f8189695e/docutils-0.23-py3-none-any.whl @@ -428,28 +424,25 @@ environments: - pypi: https://files.pythonhosted.org/packages/42/77/de194443bf38daed9452139e960c632b0ef9f9a5dd9ce605fdf18ca9f1b1/id-1.6.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/2b/36e984399089c026a6499ac8f7401d38487cf0183839a4aa78140d373771/treescope-0.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/03/5b668e78eff52a459c707e442a3cbd3e0f8b74d08a4b92111a07159aff11/mujoco_mjx-3.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3f/efeb7c6801c46e11bd666a5180f0d615f74f72264212f74f39586c6fda9d/glfw-2.10.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.py39.py310.py311.py312.py313.py314-none-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5c/6e/5087e0347188f6970aba1ffbd0018754d23c3f3461e9f21785f2f27a02c2/jax-0.10.1-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5e/c6/82669e70cef67c803852285ba6f59d7e3d102983c0ab4be8269c14756677/tensorstore-0.1.84-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7c/f0/21f81892e4ed10f4ec3ef2e7cf8635fb76e7c0907c55d0da66be50094760/farama_notifications-0.0.6-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/7f/66/b15ce62552d84bbfcec9a4873ab79d993a1dd4edb922cbfccae192bd5b5f/jaraco.classes-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/81/db/e655086b7f3a705df045bf0933bdd9c2f79bb3c97bfef1384598bb79a217/keyring-25.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8a/69/6a93d8600c339d7687a05857c7907bd4dd8cf88691a5ea106d7a50af90a1/optax-0.2.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/65/4bd2abfd4cb6e917b2626de5cbfc034dfc94b74dd95b8272d93f2ad66bed/flax-0.12.7-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/95/38/0d29a6fd7d0d1373f0c0c88a04ba20e359b257753ac497564cd660fc1d55/cryptography-48.0.0-cp311-abi3-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/96/9a/982e48afcffcd727a9144506720ffd4224b6b7e355c98641866f38b7c043/jaraco_functools-4.5.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/97/1b/295bf2fa3e740131778065e5ffa2c481f0e7210182d408e9a2c244ff5b0c/readme_renderer-45.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/97/df/a1495de78c1da3e8e93978dd177b04d18aaa7361452e30a3467c41c3b19e/mujoco-3.9.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/98/df/0a1755e750013a2081e863e7cd37e0cdd02664372c754e5560099eb7aa44/cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/a3/e137168c9c44d18eff0376253da9f1e9234d0239e0ee230d2fee6cea8e55/jeepney-0.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl @@ -461,12 +454,13 @@ environments: - pypi: https://files.pythonhosted.org/packages/c7/d1/63b5014a6184210292c66944f051e9fc95c0272fe5464d1b1a2de5de0104/orbax_checkpoint-0.12.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d5/0c/043d5e551459da400957a1395e0febbf771446ff34291afcbe3d8be2a279/fsspec-2026.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/de/e4/1ba6f44e491c4eece978685230dde56b14d51a0365bc1b774ddaa94d14cd/pyopengl-3.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/e1/67/921ec3024056483db83953ae8e48079ad62b92db7880013ca77632921dd0/readme_renderer-44.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e4/be/5b3cfe508bfab6761414ff944e3366eb13be4fd71efcd69450f89ba39f43/protobuf-7.35.1-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e8/3d/1087453384dbde46a8c7f9356eead2c58be8a7bf156bca40243377c85715/more_itertools-11.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/73/fda6a25f3beeb5e49d74330b44092b9e5a547395ccd478d1103ddcbff1fc/gymnasium-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/eb/33/40cd74219417e78b97c47802037cf2d87b91973e18bb968a7da48a96ea44/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/f2/58/bc8954bda5fcda97bd7c19be11b85f91973d67a706ed4a3aec33e7de22db/jaraco_context-6.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f5/5f/f17563f28ff03c7b6799c50d01d5d856a1d55f2676f537ca8d28c7f627cd/scipy-1.17.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/fb/63/68f5d0ea81e167db5f59ddb94dc6f837667062113feff1c73fabf8907061/msgpack-1.2.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/ff/9a/9afaade874b2fa6c752c36f1548f718b5b83af81ed9b76628329dab81c1b/rfc3986-2.0.0-py2.py3-none-any.whl osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2026.5.20-hbd8a1cb_0.conda @@ -487,19 +481,20 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/urllib3-2.7.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-4.1.0-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/_openmp_mutex-4.5-7_kmp_llvm.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/backports.zstd-1.5.0-py313h7208f8c_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/backports.zstd-1.6.0-py313h7208f8c_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/brotli-python-1.2.0-py313hde1f3bb_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/bzip2-1.0.8-hd037594_9.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/icu-78.3-hef89b57_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/lcms2-2.19.1-hdfa7624_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/lerc-4.1.0-h1eee2c3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libblas-3.11.0-8_h51639a9_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libcblas-3.11.0-8_hb0561ab_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libcxx-22.1.7-h55c6f16_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libdeflate-1.25-hc11a715_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libexpat-2.8.1-hf6b4638_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libexpat-2.8.1-hf6b4638_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libffi-3.5.2-hcf2aa1b_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype-2.14.3-hce30654_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype6-2.14.3-hdfa99f5_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype-2.14.3-hce30654_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype6-2.14.3-hdfa99f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgcc-15.2.0-hcbb3090_19.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgfortran-15.2.0-h07b0088_19.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgfortran5-15.2.0-hdae7583_19.conda @@ -509,7 +504,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libmpdec-4.0.0-h84a0fba_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libopenblas-0.3.33-openmp_he657e61_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libpng-1.6.58-h132b30e_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.53.1-h1b79a29_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.53.2-h1ae2325_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libtiff-4.7.1-h4030677_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libwebp-base-1.6.0-h07db88b_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libxcb-1.17.0-hdb1d25a_0.conda @@ -518,13 +513,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ncurses-6.6-h1d4f5a5_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-2.4.6-py313hce9b930_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openjpeg-2.5.4-hd9e9057_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.6.2-hd24854e_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.6.3-hd24854e_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pillow-12.2.0-py313h45e5a15_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pthread-stubs-0.4-hd74edd7_1002.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.13.13-h20e6be0_100_cp313.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.13.14-h448ec07_100_cp313.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pyyaml-6.0.3-py313h65a2061_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/readline-8.3-h46df422_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.15.15-h80928e0_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.15.17-h80928e0_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tk-8.6.13-h010d191_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxau-1.0.12-hc919400_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxdmcp-1.1.5-hc919400_1.conda @@ -532,12 +527,12 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: ./ - - pypi: ./submodules/drone-controllers - - pypi: ./submodules/drone-models - pypi: https://files.pythonhosted.org/packages/05/98/716a473cfb24750858ddd5d14e6527539dd206583a46408d08eeb2844a75/trimesh-4.12.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/09/b7/087fcbfe2a0a0b44e236c9853d7fa7c539db6b8c60ab5702fffd73be5a7c/casadi-3.7.2-cp313-none-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/0d/fe/6bea5c9162869c5beba5d9c8abbed835ec85bf1ec1fba05a3822325c45f3/build-1.5.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/10/03/8aeeb7458d22546bf64b5250ca1daeb5ff757d900e8e4a7476c6f0db843e/protobuf-7.35.1-cp310-abi3-macosx_10_9_universal2.whl - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/91/30151a39f7570f448ed84529390628a651d7f27c87d73c9b887f8189695e/docutils-0.23-py3-none-any.whl @@ -547,7 +542,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/42/77/de194443bf38daed9452139e960c632b0ef9f9a5dd9ce605fdf18ca9f1b1/id-1.6.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/2b/36e984399089c026a6499ac8f7401d38487cf0183839a4aa78140d373771/treescope-0.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/03/5b668e78eff52a459c707e442a3cbd3e0f8b74d08a4b92111a07159aff11/mujoco_mjx-3.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl @@ -557,17 +551,16 @@ environments: - pypi: https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/81/db/e655086b7f3a705df045bf0933bdd9c2f79bb3c97bfef1384598bb79a217/keyring-25.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/83/ee/93d06e358a4aa32280b00e722d3ea0a1f25fc3cc5778d80581c9cca2c10e/protobuf-7.35.0-cp310-abi3-macosx_10_9_universal2.whl - pypi: https://files.pythonhosted.org/packages/85/30/d162e99746a2fb1d98bb0ef23af3e201b156cf09f7de867c7390c8fe1c06/nh3-0.3.5-cp38-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl + - pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/89/8c/182a2a593195bfd39842ea68ebc084e20c850806117213f5a299dfc513d9/uvloop-0.22.1-cp313-cp313-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/8a/69/6a93d8600c339d7687a05857c7907bd4dd8cf88691a5ea106d7a50af90a1/optax-0.2.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/65/4bd2abfd4cb6e917b2626de5cbfc034dfc94b74dd95b8272d93f2ad66bed/flax-0.12.7-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/92/dc/c385f38f2c2433333345a82926c6bfa5ecfff3ef787201614317b58dd8be/msgpack-1.1.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/96/9a/982e48afcffcd727a9144506720ffd4224b6b7e355c98641866f38b7c043/jaraco_functools-4.5.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/97/1b/295bf2fa3e740131778065e5ffa2c481f0e7210182d408e9a2c244ff5b0c/readme_renderer-45.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/9a/e5/54cb7c50ad5fdc1e0a86b7df4b135c2cbd5c4623605aa94466659098e8da/simplejson-4.1.1-cp313-cp313-macosx_11_0_arm64.whl - - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b1/88/a29bca408a4c2db6c5bcf58a8b92c464660b7f846c559abd9110783574cb/mujoco-3.9.0-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl @@ -576,11 +569,11 @@ environments: - pypi: https://files.pythonhosted.org/packages/bd/24/12818598c362d7f300f18e74db45963dbcb85150324092410c8b49405e42/pyproject_hooks-1.2.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c5/7b/bca5613a0c3b542420cf92bd5e5fb8ebd5435ce1011a091f66bb7693285e/humanize-4.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/d1/63b5014a6184210292c66944f051e9fc95c0272fe5464d1b1a2de5de0104/orbax_checkpoint-0.12.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/59/7e6b812629d2f919e586041bffc130e1af32079f71bb20699eed54ed6d92/msgpack-1.2.0-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/cf/76/3b637d4def229015a3035a7b44fac0dcf2536ae337540cdbffc651334d4e/jaxlib-0.10.1-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/d5/0c/043d5e551459da400957a1395e0febbf771446ff34291afcbe3d8be2a279/fsspec-2026.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d9/a1/4008f14bbc616cfb1ac5b39ea485f9c63031c4634ab3f4cf72e7541f816a/ml_dtypes-0.5.4-cp313-cp313-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/de/e4/1ba6f44e491c4eece978685230dde56b14d51a0365bc1b774ddaa94d14cd/pyopengl-3.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/e1/67/921ec3024056483db83953ae8e48079ad62b92db7880013ca77632921dd0/readme_renderer-44.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e8/3d/1087453384dbde46a8c7f9356eead2c58be8a7bf156bca40243377c85715/more_itertools-11.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/73/fda6a25f3beeb5e49d74330b44092b9e5a547395ccd478d1103ddcbff1fc/gymnasium-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ec/ae/db19f8ab842e9b724bf5dbb7db29302a91f1e55bc4d04b1025d6d605a2c5/scipy-1.17.1-cp313-cp313-macosx_12_0_arm64.whl @@ -594,7 +587,7 @@ environments: packages: linux-64: - conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-20_gnu.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/backports.zstd-1.5.0-py313h18e8e13_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/backports.zstd-1.6.0-py313h18e8e13_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/brotli-python-1.2.0-py313hf159716_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hda65f42_9.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.19.1-h0c24ade_1.conda @@ -603,7 +596,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.11.0-8_h4a7cf45_openblas.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.11.0-8_h0358290_openblas.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.25-h17f619e_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.5.2-h3435931_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.3-ha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.3-h73754d4_0.conda @@ -617,7 +610,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libmpdec-4.0.0-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.33-pthreads_h94d23a6_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.58-h421ea60_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.1-h0c1763c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.2-h0c1763c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_19.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.1-h9d88235_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.42.1-h5347b49_0.conda @@ -628,13 +621,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.6-hdb14827_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.4.6-py313hf6604e3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.4-h55fea9a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.2-h35e630c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.3-h35e630c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pillow-12.2.0-py313h80991f8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.13-h6add32d_100_cp313.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.14-h6add32d_100_cp313.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pyyaml-6.0.3-py313h3dea7bd_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.15-h6a952e8_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.17-h6a952e8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h366c992_103.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/watchdog-6.0.0-py313hd5f5364_3.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxau-1.0.12-hb03c661_1.conda @@ -681,41 +674,37 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/urllib3-2.7.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-4.1.0-pyhcf101f3_0.conda - pypi: ./ - - pypi: ./submodules/drone-controllers - - pypi: ./submodules/drone-models - pypi: https://files.pythonhosted.org/packages/05/98/716a473cfb24750858ddd5d14e6527539dd206583a46408d08eeb2844a75/trimesh-4.12.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/11/8c/c9138d881c79aa0ea9ed83cbd58d5ca75624378b38cee225dcf5c42cc91f/griffelib-2.0.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/15/c0/0be24758891ef825f2065cd5db8741aaddabe3e248ee6acc5e8a80f04005/uvloop-0.22.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/1d/69a0ba52fb546261e71a7209378ee6059950e9c088a2a18355e01509f474/jaxlib-0.10.1-cp313-cp313-manylinux_2_27_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/28/de/a3e710469772c6a89595fc52816da05c1e164b4c866a89e3cb82fb1b67c5/mkdocs_autorefs-1.4.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/31/6f/4015dbb4c26bf1fc4b5b637188fc47ec2f1781baccc2e13b0c48887ae9b0/mkdocs_charts_plugin-0.0.13-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/32/28/79f0f8de97cce916d5ae88a7bee1ad724855e83e6019c0b4d5b3fabc80f3/mkdocstrings_python-2.0.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/38/2e/21a3ede87f0bf82d6c7bcb90480d50a6490eb974c6ab20881188e440957c/simplejson-4.1.1-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl - pypi: https://files.pythonhosted.org/packages/3f/5b/7120e22f6e22ca77283f4a086ab2e59d107f00bfc952116db41a015385fe/casadi-3.7.2-cp313-none-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/2b/36e984399089c026a6499ac8f7401d38487cf0183839a4aa78140d373771/treescope-0.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/03/5b668e78eff52a459c707e442a3cbd3e0f8b74d08a4b92111a07159aff11/mujoco_mjx-3.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4e/2c/bcf1ae903975ad6f169abb05c1eb0f94395478364deb89270cf034081b29/mkdocs_literate_nav-0.6.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3f/efeb7c6801c46e11bd666a5180f0d615f74f72264212f74f39586c6fda9d/glfw-2.10.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.py39.py310.py311.py312.py313.py314-none-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5c/6e/5087e0347188f6970aba1ffbd0018754d23c3f3461e9f21785f2f27a02c2/jax-0.10.1-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5e/c6/82669e70cef67c803852285ba6f59d7e3d102983c0ab4be8269c14756677/tensorstore-0.1.84-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/5e/e3/00ec594aef5f55522e6d373bc2ac53e53a8f5e9ae32f2d6854b0de4270f3/mkdocstrings_python-2.0.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/6e/94/be70f8ee9c45f2f62b39a1f0e9303bc20e138a8f3b8e50ffd89498e177e1/mkdocstrings-1.0.4-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7c/f0/21f81892e4ed10f4ec3ef2e7cf8635fb76e7c0907c55d0da66be50094760/farama_notifications-0.0.6-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8a/69/6a93d8600c339d7687a05857c7907bd4dd8cf88691a5ea106d7a50af90a1/optax-0.2.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/65/4bd2abfd4cb6e917b2626de5cbfc034dfc94b74dd95b8272d93f2ad66bed/flax-0.12.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/97/df/a1495de78c1da3e8e93978dd177b04d18aaa7361452e30a3467c41c3b19e/mujoco-3.9.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/4d/a330cab5e055d45e924cec69da54a3d8ed37643964f8d1fa1a772b496273/mkdocs_section_index-0.3.12-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl @@ -725,10 +714,12 @@ environments: - pypi: https://files.pythonhosted.org/packages/c7/d1/63b5014a6184210292c66944f051e9fc95c0272fe5464d1b1a2de5de0104/orbax_checkpoint-0.12.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d5/0c/043d5e551459da400957a1395e0febbf771446ff34291afcbe3d8be2a279/fsspec-2026.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/de/e4/1ba6f44e491c4eece978685230dde56b14d51a0365bc1b774ddaa94d14cd/pyopengl-3.1.10-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e4/be/5b3cfe508bfab6761414ff944e3366eb13be4fd71efcd69450f89ba39f43/protobuf-7.35.1-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e9/73/fda6a25f3beeb5e49d74330b44092b9e5a547395ccd478d1103ddcbff1fc/gymnasium-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/eb/33/40cd74219417e78b97c47802037cf2d87b91973e18bb968a7da48a96ea44/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/ee/1b/3075eb67fe66e19db059f0a25744c4e56978a309603a20e1d3353d545b5e/mkdocs_gen_files-0.6.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f5/5f/f17563f28ff03c7b6799c50d01d5d856a1d55f2676f537ca8d28c7f627cd/scipy-1.17.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/fb/63/68f5d0ea81e167db5f59ddb94dc6f837667062113feff1c73fabf8907061/msgpack-1.2.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/babel-2.18.0-pyhcf101f3_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/backrefs-7.0-pyhcf101f3_0.conda @@ -769,19 +760,20 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/urllib3-2.7.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-4.1.0-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/_openmp_mutex-4.5-7_kmp_llvm.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/backports.zstd-1.5.0-py313h7208f8c_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/backports.zstd-1.6.0-py313h7208f8c_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/brotli-python-1.2.0-py313hde1f3bb_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/bzip2-1.0.8-hd037594_9.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/icu-78.3-hef89b57_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/lcms2-2.19.1-hdfa7624_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/lerc-4.1.0-h1eee2c3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libblas-3.11.0-8_h51639a9_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libcblas-3.11.0-8_hb0561ab_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libcxx-22.1.7-h55c6f16_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libdeflate-1.25-hc11a715_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libexpat-2.8.1-hf6b4638_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libexpat-2.8.1-hf6b4638_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libffi-3.5.2-hcf2aa1b_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype-2.14.3-hce30654_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype6-2.14.3-hdfa99f5_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype-2.14.3-hce30654_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype6-2.14.3-hdfa99f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgcc-15.2.0-hcbb3090_19.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgfortran-15.2.0-h07b0088_19.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgfortran5-15.2.0-hdae7583_19.conda @@ -791,7 +783,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libmpdec-4.0.0-h84a0fba_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libopenblas-0.3.33-openmp_he657e61_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libpng-1.6.58-h132b30e_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.53.1-h1b79a29_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.53.2-h1ae2325_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libtiff-4.7.1-h4030677_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libwebp-base-1.6.0-h07db88b_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libxcb-1.17.0-hdb1d25a_0.conda @@ -801,13 +793,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ncurses-6.6-h1d4f5a5_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-2.4.6-py313hce9b930_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openjpeg-2.5.4-hd9e9057_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.6.2-hd24854e_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.6.3-hd24854e_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pillow-12.2.0-py313h45e5a15_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pthread-stubs-0.4-hd74edd7_1002.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.13.13-h20e6be0_100_cp313.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.13.14-h448ec07_100_cp313.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pyyaml-6.0.3-py313h65a2061_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/readline-8.3-h46df422_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.15.15-h80928e0_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.15.17-h80928e0_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tk-8.6.13-h010d191_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/watchdog-6.0.0-py313h6688731_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxau-1.0.12-hc919400_1.conda @@ -816,39 +808,36 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: ./ - - pypi: ./submodules/drone-controllers - - pypi: ./submodules/drone-models - pypi: https://files.pythonhosted.org/packages/05/98/716a473cfb24750858ddd5d14e6527539dd206583a46408d08eeb2844a75/trimesh-4.12.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/09/b7/087fcbfe2a0a0b44e236c9853d7fa7c539db6b8c60ab5702fffd73be5a7c/casadi-3.7.2-cp313-none-macosx_11_0_arm64.whl + - pypi: https://files.pythonhosted.org/packages/10/03/8aeeb7458d22546bf64b5250ca1daeb5ff757d900e8e4a7476c6f0db843e/protobuf-7.35.1-cp310-abi3-macosx_10_9_universal2.whl - pypi: https://files.pythonhosted.org/packages/11/8c/c9138d881c79aa0ea9ed83cbd58d5ca75624378b38cee225dcf5c42cc91f/griffelib-2.0.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/28/de/a3e710469772c6a89595fc52816da05c1e164b4c866a89e3cb82fb1b67c5/mkdocs_autorefs-1.4.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/31/6f/4015dbb4c26bf1fc4b5b637188fc47ec2f1781baccc2e13b0c48887ae9b0/mkdocs_charts_plugin-0.0.13-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/32/28/79f0f8de97cce916d5ae88a7bee1ad724855e83e6019c0b4d5b3fabc80f3/mkdocstrings_python-2.0.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/83/4f3c6ef9bed01f384036c2030b3901cf075bbc8eff6e4529e502f0283ab5/tensorstore-0.1.84-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/2b/36e984399089c026a6499ac8f7401d38487cf0183839a4aa78140d373771/treescope-0.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/03/5b668e78eff52a459c707e442a3cbd3e0f8b74d08a4b92111a07159aff11/mujoco_mjx-3.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4e/2c/bcf1ae903975ad6f169abb05c1eb0f94395478364deb89270cf034081b29/mkdocs_literate_nav-0.6.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5c/6e/5087e0347188f6970aba1ffbd0018754d23c3f3461e9f21785f2f27a02c2/jax-0.10.1-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/5e/e3/00ec594aef5f55522e6d373bc2ac53e53a8f5e9ae32f2d6854b0de4270f3/mkdocstrings_python-2.0.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/6e/94/be70f8ee9c45f2f62b39a1f0e9303bc20e138a8f3b8e50ffd89498e177e1/mkdocstrings-1.0.4-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/7c/f0/21f81892e4ed10f4ec3ef2e7cf8635fb76e7c0907c55d0da66be50094760/farama_notifications-0.0.6-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/83/ee/93d06e358a4aa32280b00e722d3ea0a1f25fc3cc5778d80581c9cca2c10e/protobuf-7.35.0-cp310-abi3-macosx_10_9_universal2.whl + - pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/89/8c/182a2a593195bfd39842ea68ebc084e20c850806117213f5a299dfc513d9/uvloop-0.22.1-cp313-cp313-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/8a/69/6a93d8600c339d7687a05857c7907bd4dd8cf88691a5ea106d7a50af90a1/optax-0.2.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/65/4bd2abfd4cb6e917b2626de5cbfc034dfc94b74dd95b8272d93f2ad66bed/flax-0.12.7-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/92/dc/c385f38f2c2433333345a82926c6bfa5ecfff3ef787201614317b58dd8be/msgpack-1.1.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/9a/e5/54cb7c50ad5fdc1e0a86b7df4b135c2cbd5c4623605aa94466659098e8da/simplejson-4.1.1-cp313-cp313-macosx_11_0_arm64.whl - - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b0/4d/a330cab5e055d45e924cec69da54a3d8ed37643964f8d1fa1a772b496273/mkdocs_section_index-0.3.12-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b1/88/a29bca408a4c2db6c5bcf58a8b92c464660b7f846c559abd9110783574cb/mujoco-3.9.0-cp313-cp313-macosx_11_0_arm64.whl @@ -857,6 +846,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/bc/8a/340a1555ae33d7354dbca4faa54948d76d89a27ceef032c8c3bc661d003e/aiofiles-25.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c5/7b/bca5613a0c3b542420cf92bd5e5fb8ebd5435ce1011a091f66bb7693285e/humanize-4.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/d1/63b5014a6184210292c66944f051e9fc95c0272fe5464d1b1a2de5de0104/orbax_checkpoint-0.12.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/59/7e6b812629d2f919e586041bffc130e1af32079f71bb20699eed54ed6d92/msgpack-1.2.0-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/cf/76/3b637d4def229015a3035a7b44fac0dcf2536ae337540cdbffc651334d4e/jaxlib-0.10.1-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/d5/0c/043d5e551459da400957a1395e0febbf771446ff34291afcbe3d8be2a279/fsspec-2026.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d9/a1/4008f14bbc616cfb1ac5b39ea485f9c63031c4634ab3f4cf72e7541f816a/ml_dtypes-0.5.4-cp313-cp313-macosx_10_13_universal2.whl @@ -879,7 +869,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.11.0-8_h4a7cf45_openblas.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.11.0-8_h0358290_openblas.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.25-h17f619e_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.5.2-h3435931_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.3-ha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.3-h73754d4_0.conda @@ -894,7 +884,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.1-hb9d3cd8_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.33-pthreads_h94d23a6_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.58-h421ea60_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.1-h0c1763c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.2-h0c1763c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_19.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.1-h9d88235_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.42.1-h5347b49_0.conda @@ -905,12 +895,12 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.6-hdb14827_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.4.6-py312h33ff503_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.4-h55fea9a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.2-h35e630c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.3-h35e630c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pillow-12.2.0-py312h50c33e8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.12.13-hd63d673_0_cpython.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.15-h6a952e8_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.17-h6a952e8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h366c992_103.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxau-1.0.12-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxdmcp-1.1.5-hb03c661_1.conda @@ -926,8 +916,6 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-hc9c84f9_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/wheel-0.47.0-pyhd8ed1ab_0.conda - pypi: ./ - - pypi: ./submodules/drone-controllers - - pypi: ./submodules/drone-models - pypi: https://files.pythonhosted.org/packages/01/8e/1e35281b8ab6d5d72ebe9911edcdffa3f36b04ed9d51dec6dd140396e220/scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/05/98/716a473cfb24750858ddd5d14e6527539dd206583a46408d08eeb2844a75/trimesh-4.12.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0c/b6/156a8de1e1b47694f0e7de6675866936608d45dc68388fd017d36f8693be/simplejson-4.1.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl @@ -935,30 +923,30 @@ environments: - pypi: https://files.pythonhosted.org/packages/18/2a/d4cd8506d2044e082f8cd921be57392e6a9b5ccd3ffdf050362430a3d5d5/nvidia_cuda_cccl_cu12-12.9.27-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/18/7c/b7b24e10e5cb0213c85204d53fcd60d0568d986ea0001a00a815e14e01e1/tensorstore-0.1.84-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/24/ab/d7233c915b12c005655437c6c4cf0ae46cbbb2b20d743cb5e4881ad3104a/casadi-3.7.2-cp312-none-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/25/48/b54a06168a2190572a312bfe4ce443687773eb61367ced31e064953dd2f7/nvidia_cuda_nvcc_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/33/40/79b0c64d44d6c166c0964ec1d803d067f4a145cca23e23925fd351d0e642/nvidia_cusolver_cu12-11.7.5.82-py3-none-manylinux_2_27_x86_64.whl - pypi: https://files.pythonhosted.org/packages/38/69/2912ab63036e21c72748019e1d8e09e8a1fc3368b3e83fc27898a1858575/jaxlib-0.10.1-cp312-cp312-manylinux_2_27_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/38/ed/b7728573156d70b6b094233b0f38d876fc37340826cf852347ec2c7ca8ca/msgpack-1.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/3a/13/547360d81e6d88d58492968ffda9f9542854f11310ee556fef14260cc886/zipp-4.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/3a/cb/28ce52eb94390dda42599c98ea0204d74799e4d8047a0eb559b6fd648056/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/3e/85/1c12e849e4d50624e75496378a3fb168389f768d3ec7cb694fba873ff9a8/nvidia_nvshmem_cu12-3.7.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/2b/36e984399089c026a6499ac8f7401d38487cf0183839a4aa78140d373771/treescope-0.1.10-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/46/0c/c75bbfb967457a0b7670b8ad267bfc4fffdf341c074e0a80db06c24ccfd4/nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/03/5b668e78eff52a459c707e442a3cbd3e0f8b74d08a4b92111a07159aff11/mujoco_mjx-3.9.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/50/32/e7ffa9c324ae260e5dbb4af2cd557bf7a8d155c8ac7b79a785fe1796fb92/nvidia_nccl_cu12-2.30.7-py3-none-manylinux_2_18_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3f/efeb7c6801c46e11bd666a5180f0d615f74f72264212f74f39586c6fda9d/glfw-2.10.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.py39.py310.py311.py312.py313.py314-none-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5c/6e/5087e0347188f6970aba1ffbd0018754d23c3f3461e9f21785f2f27a02c2/jax-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5f/6f/e62b4dfc7ad6518e7eff2516f680d02a0f6eb62c0c212e152ca708a0085e/uvloop-0.22.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/65/92/a5100f7185a800a5d29f8d14041f61475b9de465ffcc0f3b9fba606e4505/msgpack-1.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/6b/c3/0e45ff4dce8401f6ea7c25d80d75738813a47f5ae2691e2478f2fd1e5e93/nvidia_nccl_cu12-2.30.4-py3-none-manylinux_2_18_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7c/f0/21f81892e4ed10f4ec3ef2e7cf8635fb76e7c0907c55d0da66be50094760/farama_notifications-0.0.6-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/7d/9d/1a383211b0967e702b9e84643986fb31bf35ca07bddc19e0cf139fd3291d/nvidia_cudnn_cu12-9.23.0.39-py3-none-manylinux_2_27_x86_64.whl - pypi: https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8a/3f/cabd3c791ff5042df157609e00e96440ccaba69f72bccd8e3470d85fdd48/jax_cuda12_plugin-0.10.1-cp312-cp312-manylinux_2_27_x86_64.whl - pypi: https://files.pythonhosted.org/packages/8a/69/6a93d8600c339d7687a05857c7907bd4dd8cf88691a5ea106d7a50af90a1/optax-0.2.8-py3-none-any.whl @@ -966,8 +954,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/65/4bd2abfd4cb6e917b2626de5cbfc034dfc94b74dd95b8272d93f2ad66bed/flax-0.12.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/95/f4/61e6996dd20481ee834f57a8e9dca28b1869366a135e0d42e2aa8493bdd4/nvidia_cufft_cu12-11.4.1.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/9e/da/36fa8307cc40889307fed415d70b67d35ec330ffce889a9c03cf8f616cfa/nvidia_nvshmem_cu12-3.6.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/a0/7e/e0a5d44bf070a1ff945050abc02ef1cff5ca9c6ab5dc6a16ab6322593a32/nvidia_cudnn_cu12-9.23.1.3-py3-none-manylinux_2_27_x86_64.whl - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl @@ -981,6 +968,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/d5/01/b2a88b6b73df933d5ab38583240c296684b626a8de3c3bb9a7c2fd356f08/mujoco-3.9.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d5/0c/043d5e551459da400957a1395e0febbf771446ff34291afcbe3d8be2a279/fsspec-2026.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/de/e4/1ba6f44e491c4eece978685230dde56b14d51a0365bc1b774ddaa94d14cd/pyopengl-3.1.10-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e4/be/5b3cfe508bfab6761414ff944e3366eb13be4fd71efcd69450f89ba39f43/protobuf-7.35.1-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e9/73/fda6a25f3beeb5e49d74330b44092b9e5a547395ccd478d1103ddcbff1fc/gymnasium-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f7/a1/47c08a81760cae84c4a4aa720f3fc1ce3bac6f7aafa5ab82c302d7946f07/jax_cuda12_pjrt-0.10.1-py3-none-manylinux_2_27_x86_64.whl gpu-tests: @@ -991,7 +979,7 @@ environments: packages: linux-64: - conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-20_gnu.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.16-hb03c661_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.16.1-hb03c661_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/brotli-1.2.0-hed03a55_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.2.0-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hda65f42_9.conda @@ -1017,14 +1005,14 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libbrotlidec-1.2.0-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libbrotlienc-1.2.0-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.11.0-8_h0358290_openblas.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libclang13-22.1.6-default_h746c552_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libclang13-22.1.7-default_h746c552_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h7a8fb5f_6.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.25-h17f619e_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libdrm-2.4.127-hb03c661_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20250104-pl5321h7949ede_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libegl-1.7.0-ha4b6fd6_3.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libegl-devel-1.7.0-ha4b6fd6_3.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.5.2-h3435931_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.3-ha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.3-h73754d4_0.conda @@ -1051,7 +1039,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libpciaccess-0.19-hb03c661_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.58-h421ea60_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libpq-18.4-hd5a49e9_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.1-h0c1763c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.2-h0c1763c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_19.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-15.2.0-hdf11a46_19.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.1-h9d88235_1.conda @@ -1071,7 +1059,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.4.6-py312h33ff503_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.4-h55fea9a_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openldap-2.6.13-hbde042b_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.2-h35e630c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.3-h35e630c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.47-haa7fec5_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pillow-12.2.0-py312h50c33e8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pixman-0.46.4-h54a6638_1.conda @@ -1079,11 +1067,11 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/pyside6-6.11.1-py312h50ac2ff_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.12.13-hd63d673_0_cpython.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/qhull-2020.2-h434a139_5.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/qt6-main-6.11.1-pl5321h16c4a6b_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/qt6-main-6.11.1-pl5321h16c4a6b_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.15-h6a952e8_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.17-h6a952e8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h366c992_103.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/tornado-6.5.6-py312h4c3975b_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/tornado-6.5.7-py312h4c3975b_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/unicodedata2-17.0.1-py312h4c3975b_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/wayland-1.25.0-hd6090a7_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.1-h4f16b4b_2.conda @@ -1092,7 +1080,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-keysyms-0.4.1-hb711507_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-renderutil-0.3.10-hb711507_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-wm-0.4.2-hb711507_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.47-hb03c661_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.47-h280c20c_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libice-1.1.2-hb9d3cd8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.6-he73a12e_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.13-he1eb515_0.conda @@ -1140,8 +1128,6 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-hc9c84f9_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/wheel-0.47.0-pyhd8ed1ab_0.conda - pypi: ./ - - pypi: ./submodules/drone-controllers - - pypi: ./submodules/drone-models - pypi: https://files.pythonhosted.org/packages/01/8e/1e35281b8ab6d5d72ebe9911edcdffa3f36b04ed9d51dec6dd140396e220/scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/05/98/716a473cfb24750858ddd5d14e6527539dd206583a46408d08eeb2844a75/trimesh-4.12.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0c/b6/156a8de1e1b47694f0e7de6675866936608d45dc68388fd017d36f8693be/simplejson-4.1.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl @@ -1150,30 +1136,30 @@ environments: - pypi: https://files.pythonhosted.org/packages/18/2a/d4cd8506d2044e082f8cd921be57392e6a9b5ccd3ffdf050362430a3d5d5/nvidia_cuda_cccl_cu12-12.9.27-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/18/7c/b7b24e10e5cb0213c85204d53fcd60d0568d986ea0001a00a815e14e01e1/tensorstore-0.1.84-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/24/ab/d7233c915b12c005655437c6c4cf0ae46cbbb2b20d743cb5e4881ad3104a/casadi-3.7.2-cp312-none-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/25/48/b54a06168a2190572a312bfe4ce443687773eb61367ced31e064953dd2f7/nvidia_cuda_nvcc_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/33/40/79b0c64d44d6c166c0964ec1d803d067f4a145cca23e23925fd351d0e642/nvidia_cusolver_cu12-11.7.5.82-py3-none-manylinux_2_27_x86_64.whl - pypi: https://files.pythonhosted.org/packages/38/69/2912ab63036e21c72748019e1d8e09e8a1fc3368b3e83fc27898a1858575/jaxlib-0.10.1-cp312-cp312-manylinux_2_27_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/38/ed/b7728573156d70b6b094233b0f38d876fc37340826cf852347ec2c7ca8ca/msgpack-1.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/3a/13/547360d81e6d88d58492968ffda9f9542854f11310ee556fef14260cc886/zipp-4.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/3a/cb/28ce52eb94390dda42599c98ea0204d74799e4d8047a0eb559b6fd648056/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/3e/85/1c12e849e4d50624e75496378a3fb168389f768d3ec7cb694fba873ff9a8/nvidia_nvshmem_cu12-3.7.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - pypi: https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/2b/36e984399089c026a6499ac8f7401d38487cf0183839a4aa78140d373771/treescope-0.1.10-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/46/0c/c75bbfb967457a0b7670b8ad267bfc4fffdf341c074e0a80db06c24ccfd4/nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/03/5b668e78eff52a459c707e442a3cbd3e0f8b74d08a4b92111a07159aff11/mujoco_mjx-3.9.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/50/32/e7ffa9c324ae260e5dbb4af2cd557bf7a8d155c8ac7b79a785fe1796fb92/nvidia_nccl_cu12-2.30.7-py3-none-manylinux_2_18_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3f/efeb7c6801c46e11bd666a5180f0d615f74f72264212f74f39586c6fda9d/glfw-2.10.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.py39.py310.py311.py312.py313.py314-none-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5c/6e/5087e0347188f6970aba1ffbd0018754d23c3f3461e9f21785f2f27a02c2/jax-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5f/6f/e62b4dfc7ad6518e7eff2516f680d02a0f6eb62c0c212e152ca708a0085e/uvloop-0.22.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/65/92/a5100f7185a800a5d29f8d14041f61475b9de465ffcc0f3b9fba606e4505/msgpack-1.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/6b/c3/0e45ff4dce8401f6ea7c25d80d75738813a47f5ae2691e2478f2fd1e5e93/nvidia_nccl_cu12-2.30.4-py3-none-manylinux_2_18_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7c/f0/21f81892e4ed10f4ec3ef2e7cf8635fb76e7c0907c55d0da66be50094760/farama_notifications-0.0.6-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/7d/9d/1a383211b0967e702b9e84643986fb31bf35ca07bddc19e0cf139fd3291d/nvidia_cudnn_cu12-9.23.0.39-py3-none-manylinux_2_27_x86_64.whl - pypi: https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8a/3f/cabd3c791ff5042df157609e00e96440ccaba69f72bccd8e3470d85fdd48/jax_cuda12_plugin-0.10.1-cp312-cp312-manylinux_2_27_x86_64.whl - pypi: https://files.pythonhosted.org/packages/8a/69/6a93d8600c339d7687a05857c7907bd4dd8cf88691a5ea106d7a50af90a1/optax-0.2.8-py3-none-any.whl @@ -1181,8 +1167,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/65/4bd2abfd4cb6e917b2626de5cbfc034dfc94b74dd95b8272d93f2ad66bed/flax-0.12.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/95/f4/61e6996dd20481ee834f57a8e9dca28b1869366a135e0d42e2aa8493bdd4/nvidia_cufft_cu12-11.4.1.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/9e/da/36fa8307cc40889307fed415d70b67d35ec330ffce889a9c03cf8f616cfa/nvidia_nvshmem_cu12-3.6.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/a0/7e/e0a5d44bf070a1ff945050abc02ef1cff5ca9c6ab5dc6a16ab6322593a32/nvidia_cudnn_cu12-9.23.1.3-py3-none-manylinux_2_27_x86_64.whl - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl @@ -1196,6 +1181,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/d5/01/b2a88b6b73df933d5ab38583240c296684b626a8de3c3bb9a7c2fd356f08/mujoco-3.9.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d5/0c/043d5e551459da400957a1395e0febbf771446ff34291afcbe3d8be2a279/fsspec-2026.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/de/e4/1ba6f44e491c4eece978685230dde56b14d51a0365bc1b774ddaa94d14cd/pyopengl-3.1.10-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e4/be/5b3cfe508bfab6761414ff944e3366eb13be4fd71efcd69450f89ba39f43/protobuf-7.35.1-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e9/73/fda6a25f3beeb5e49d74330b44092b9e5a547395ccd478d1103ddcbff1fc/gymnasium-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f7/a1/47c08a81760cae84c4a4aa720f3fc1ce3bac6f7aafa5ab82c302d7946f07/jax_cuda12_pjrt-0.10.1-py3-none-manylinux_2_27_x86_64.whl release: @@ -1206,7 +1192,7 @@ environments: packages: linux-64: - conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-20_gnu.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/backports.zstd-1.5.0-py313h18e8e13_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/backports.zstd-1.6.0-py313h18e8e13_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/brotli-python-1.2.0-py313hf159716_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hda65f42_9.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/lcms2-2.19.1-h0c24ade_1.conda @@ -1215,7 +1201,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libblas-3.11.0-8_h4a7cf45_openblas.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.11.0-8_h0358290_openblas.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.25-h17f619e_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.5.2-h3435931_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.3-ha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.3-h73754d4_0.conda @@ -1229,7 +1215,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libmpdec-4.0.0-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libopenblas-0.3.33-pthreads_h94d23a6_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.58-h421ea60_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.1-h0c1763c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.2-h0c1763c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_19.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.1-h9d88235_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libuuid-2.42.1-h5347b49_0.conda @@ -1239,13 +1225,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.6-hdb14827_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.4.6-py313hf6604e3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.4-h55fea9a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.2-h35e630c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.3-h35e630c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pillow-12.2.0-py313h80991f8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.13-h6add32d_100_cp313.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.14-h6add32d_100_cp313.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pyyaml-6.0.3-py313h3dea7bd_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.15-h6a952e8_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.17-h6a952e8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h366c992_103.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxau-1.0.12-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libxdmcp-1.1.5-hb03c661_1.conda @@ -1270,8 +1256,6 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/urllib3-2.7.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-4.1.0-pyhcf101f3_0.conda - pypi: ./ - - pypi: ./submodules/drone-controllers - - pypi: ./submodules/drone-models - pypi: https://files.pythonhosted.org/packages/05/98/716a473cfb24750858ddd5d14e6527539dd206583a46408d08eeb2844a75/trimesh-4.12.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0c/c3/44f3fbbfa403ea2a7c779186dc20772604442dde72947e7d01069cbe98e3/pycparser-3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0d/fe/6bea5c9162869c5beba5d9c8abbed835ec85bf1ec1fba05a3822325c45f3/build-1.5.0-py3-none-any.whl @@ -1279,6 +1263,8 @@ environments: - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/1b/0e/bf298920729f216adcb002acf7ea01b90842603d2e4e2ce9b900d9ee8fab/nh3-0.3.5-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/20/1d/69a0ba52fb546261e71a7209378ee6059950e9c088a2a18355e01509f474/jaxlib-0.10.1-cp313-cp313-manylinux_2_27_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/20/2c/0622f20ff02b2ef32558733443805dc82fd4c275be01b2d19d14676f3a1b/cryptography-49.0.0-cp311-abi3-manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/91/30151a39f7570f448ed84529390628a651d7f27c87d73c9b887f8189695e/docutils-0.23-py3-none-any.whl @@ -1289,28 +1275,25 @@ environments: - pypi: https://files.pythonhosted.org/packages/42/77/de194443bf38daed9452139e960c632b0ef9f9a5dd9ce605fdf18ca9f1b1/id-1.6.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/2b/36e984399089c026a6499ac8f7401d38487cf0183839a4aa78140d373771/treescope-0.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/03/5b668e78eff52a459c707e442a3cbd3e0f8b74d08a4b92111a07159aff11/mujoco_mjx-3.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3f/efeb7c6801c46e11bd666a5180f0d615f74f72264212f74f39586c6fda9d/glfw-2.10.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.py39.py310.py311.py312.py313.py314-none-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5c/6e/5087e0347188f6970aba1ffbd0018754d23c3f3461e9f21785f2f27a02c2/jax-0.10.1-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5e/c6/82669e70cef67c803852285ba6f59d7e3d102983c0ab4be8269c14756677/tensorstore-0.1.84-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7c/f0/21f81892e4ed10f4ec3ef2e7cf8635fb76e7c0907c55d0da66be50094760/farama_notifications-0.0.6-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/7f/66/b15ce62552d84bbfcec9a4873ab79d993a1dd4edb922cbfccae192bd5b5f/jaraco.classes-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/81/db/e655086b7f3a705df045bf0933bdd9c2f79bb3c97bfef1384598bb79a217/keyring-25.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8a/69/6a93d8600c339d7687a05857c7907bd4dd8cf88691a5ea106d7a50af90a1/optax-0.2.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/65/4bd2abfd4cb6e917b2626de5cbfc034dfc94b74dd95b8272d93f2ad66bed/flax-0.12.7-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/95/38/0d29a6fd7d0d1373f0c0c88a04ba20e359b257753ac497564cd660fc1d55/cryptography-48.0.0-cp311-abi3-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/96/9a/982e48afcffcd727a9144506720ffd4224b6b7e355c98641866f38b7c043/jaraco_functools-4.5.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/97/1b/295bf2fa3e740131778065e5ffa2c481f0e7210182d408e9a2c244ff5b0c/readme_renderer-45.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/97/df/a1495de78c1da3e8e93978dd177b04d18aaa7361452e30a3467c41c3b19e/mujoco-3.9.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/98/df/0a1755e750013a2081e863e7cd37e0cdd02664372c754e5560099eb7aa44/cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/a3/e137168c9c44d18eff0376253da9f1e9234d0239e0ee230d2fee6cea8e55/jeepney-0.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl @@ -1322,12 +1305,13 @@ environments: - pypi: https://files.pythonhosted.org/packages/c7/d1/63b5014a6184210292c66944f051e9fc95c0272fe5464d1b1a2de5de0104/orbax_checkpoint-0.12.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d5/0c/043d5e551459da400957a1395e0febbf771446ff34291afcbe3d8be2a279/fsspec-2026.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/de/e4/1ba6f44e491c4eece978685230dde56b14d51a0365bc1b774ddaa94d14cd/pyopengl-3.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/e1/67/921ec3024056483db83953ae8e48079ad62b92db7880013ca77632921dd0/readme_renderer-44.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e4/be/5b3cfe508bfab6761414ff944e3366eb13be4fd71efcd69450f89ba39f43/protobuf-7.35.1-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e8/3d/1087453384dbde46a8c7f9356eead2c58be8a7bf156bca40243377c85715/more_itertools-11.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/73/fda6a25f3beeb5e49d74330b44092b9e5a547395ccd478d1103ddcbff1fc/gymnasium-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/eb/33/40cd74219417e78b97c47802037cf2d87b91973e18bb968a7da48a96ea44/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/f2/58/bc8954bda5fcda97bd7c19be11b85f91973d67a706ed4a3aec33e7de22db/jaraco_context-6.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f5/5f/f17563f28ff03c7b6799c50d01d5d856a1d55f2676f537ca8d28c7f627cd/scipy-1.17.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/fb/63/68f5d0ea81e167db5f59ddb94dc6f837667062113feff1c73fabf8907061/msgpack-1.2.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/ff/9a/9afaade874b2fa6c752c36f1548f718b5b83af81ed9b76628329dab81c1b/rfc3986-2.0.0-py2.py3-none-any.whl osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2026.5.20-hbd8a1cb_0.conda @@ -1348,19 +1332,20 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/urllib3-2.7.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-4.1.0-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/_openmp_mutex-4.5-7_kmp_llvm.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/backports.zstd-1.5.0-py313h7208f8c_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/backports.zstd-1.6.0-py313h7208f8c_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/brotli-python-1.2.0-py313hde1f3bb_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/bzip2-1.0.8-hd037594_9.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/icu-78.3-hef89b57_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/lcms2-2.19.1-hdfa7624_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/lerc-4.1.0-h1eee2c3_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libblas-3.11.0-8_h51639a9_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libcblas-3.11.0-8_hb0561ab_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libcxx-22.1.7-h55c6f16_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libdeflate-1.25-hc11a715_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libexpat-2.8.1-hf6b4638_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libexpat-2.8.1-hf6b4638_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libffi-3.5.2-hcf2aa1b_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype-2.14.3-hce30654_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype6-2.14.3-hdfa99f5_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype-2.14.3-hce30654_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype6-2.14.3-hdfa99f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgcc-15.2.0-hcbb3090_19.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgfortran-15.2.0-h07b0088_19.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgfortran5-15.2.0-hdae7583_19.conda @@ -1370,7 +1355,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libmpdec-4.0.0-h84a0fba_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libopenblas-0.3.33-openmp_he657e61_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libpng-1.6.58-h132b30e_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.53.1-h1b79a29_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.53.2-h1ae2325_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libtiff-4.7.1-h4030677_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libwebp-base-1.6.0-h07db88b_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libxcb-1.17.0-hdb1d25a_0.conda @@ -1379,13 +1364,13 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ncurses-6.6-h1d4f5a5_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-2.4.6-py313hce9b930_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openjpeg-2.5.4-hd9e9057_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.6.2-hd24854e_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.6.3-hd24854e_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pillow-12.2.0-py313h45e5a15_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pthread-stubs-0.4-hd74edd7_1002.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.13.13-h20e6be0_100_cp313.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.13.14-h448ec07_100_cp313.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pyyaml-6.0.3-py313h65a2061_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/readline-8.3-h46df422_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.15.15-h80928e0_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.15.17-h80928e0_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tk-8.6.13-h010d191_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxau-1.0.12-hc919400_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxdmcp-1.1.5-hc919400_1.conda @@ -1393,12 +1378,12 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: ./ - - pypi: ./submodules/drone-controllers - - pypi: ./submodules/drone-models - pypi: https://files.pythonhosted.org/packages/05/98/716a473cfb24750858ddd5d14e6527539dd206583a46408d08eeb2844a75/trimesh-4.12.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/09/b7/087fcbfe2a0a0b44e236c9853d7fa7c539db6b8c60ab5702fffd73be5a7c/casadi-3.7.2-cp313-none-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/0d/fe/6bea5c9162869c5beba5d9c8abbed835ec85bf1ec1fba05a3822325c45f3/build-1.5.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/10/03/8aeeb7458d22546bf64b5250ca1daeb5ff757d900e8e4a7476c6f0db843e/protobuf-7.35.1-cp310-abi3-macosx_10_9_universal2.whl - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/32/91/30151a39f7570f448ed84529390628a651d7f27c87d73c9b887f8189695e/docutils-0.23-py3-none-any.whl @@ -1408,7 +1393,6 @@ environments: - pypi: https://files.pythonhosted.org/packages/42/77/de194443bf38daed9452139e960c632b0ef9f9a5dd9ce605fdf18ca9f1b1/id-1.6.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/2b/36e984399089c026a6499ac8f7401d38487cf0183839a4aa78140d373771/treescope-0.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/03/5b668e78eff52a459c707e442a3cbd3e0f8b74d08a4b92111a07159aff11/mujoco_mjx-3.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl @@ -1418,17 +1402,16 @@ environments: - pypi: https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/81/db/e655086b7f3a705df045bf0933bdd9c2f79bb3c97bfef1384598bb79a217/keyring-25.7.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/83/ee/93d06e358a4aa32280b00e722d3ea0a1f25fc3cc5778d80581c9cca2c10e/protobuf-7.35.0-cp310-abi3-macosx_10_9_universal2.whl - pypi: https://files.pythonhosted.org/packages/85/30/d162e99746a2fb1d98bb0ef23af3e201b156cf09f7de867c7390c8fe1c06/nh3-0.3.5-cp38-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl + - pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/89/8c/182a2a593195bfd39842ea68ebc084e20c850806117213f5a299dfc513d9/uvloop-0.22.1-cp313-cp313-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/8a/69/6a93d8600c339d7687a05857c7907bd4dd8cf88691a5ea106d7a50af90a1/optax-0.2.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/65/4bd2abfd4cb6e917b2626de5cbfc034dfc94b74dd95b8272d93f2ad66bed/flax-0.12.7-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/92/dc/c385f38f2c2433333345a82926c6bfa5ecfff3ef787201614317b58dd8be/msgpack-1.1.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/96/9a/982e48afcffcd727a9144506720ffd4224b6b7e355c98641866f38b7c043/jaraco_functools-4.5.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/97/1b/295bf2fa3e740131778065e5ffa2c481f0e7210182d408e9a2c244ff5b0c/readme_renderer-45.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/9a/e5/54cb7c50ad5fdc1e0a86b7df4b135c2cbd5c4623605aa94466659098e8da/simplejson-4.1.1-cp313-cp313-macosx_11_0_arm64.whl - - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b1/88/a29bca408a4c2db6c5bcf58a8b92c464660b7f846c559abd9110783574cb/mujoco-3.9.0-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl @@ -1437,11 +1420,11 @@ environments: - pypi: https://files.pythonhosted.org/packages/bd/24/12818598c362d7f300f18e74db45963dbcb85150324092410c8b49405e42/pyproject_hooks-1.2.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c5/7b/bca5613a0c3b542420cf92bd5e5fb8ebd5435ce1011a091f66bb7693285e/humanize-4.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/d1/63b5014a6184210292c66944f051e9fc95c0272fe5464d1b1a2de5de0104/orbax_checkpoint-0.12.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/59/7e6b812629d2f919e586041bffc130e1af32079f71bb20699eed54ed6d92/msgpack-1.2.0-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/cf/76/3b637d4def229015a3035a7b44fac0dcf2536ae337540cdbffc651334d4e/jaxlib-0.10.1-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/d5/0c/043d5e551459da400957a1395e0febbf771446ff34291afcbe3d8be2a279/fsspec-2026.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d9/a1/4008f14bbc616cfb1ac5b39ea485f9c63031c4634ab3f4cf72e7541f816a/ml_dtypes-0.5.4-cp313-cp313-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/de/e4/1ba6f44e491c4eece978685230dde56b14d51a0365bc1b774ddaa94d14cd/pyopengl-3.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/e1/67/921ec3024056483db83953ae8e48079ad62b92db7880013ca77632921dd0/readme_renderer-44.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e8/3d/1087453384dbde46a8c7f9356eead2c58be8a7bf156bca40243377c85715/more_itertools-11.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/e9/73/fda6a25f3beeb5e49d74330b44092b9e5a547395ccd478d1103ddcbff1fc/gymnasium-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ec/ae/db19f8ab842e9b724bf5dbb7db29302a91f1e55bc4d04b1025d6d605a2c5/scipy-1.17.1-cp313-cp313-macosx_12_0_arm64.whl @@ -1455,7 +1438,7 @@ environments: packages: linux-64: - conda: https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-20_gnu.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.16-hb03c661_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.16.1-hb03c661_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/brotli-1.2.0-hed03a55_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/brotli-bin-1.2.0-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/bzip2-1.0.8-hda65f42_9.conda @@ -1481,14 +1464,14 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libbrotlidec-1.2.0-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libbrotlienc-1.2.0-hb03c661_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libcblas-3.11.0-8_h0358290_openblas.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libclang13-22.1.6-default_h746c552_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libclang13-22.1.7-default_h746c552_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h7a8fb5f_6.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libdeflate-1.25-h17f619e_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libdrm-2.4.127-hb03c661_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libedit-3.1.20250104-pl5321h7949ede_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libegl-1.7.0-ha4b6fd6_3.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libegl-devel-1.7.0-ha4b6fd6_3.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.5.2-h3435931_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype-2.14.3-ha770c72_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libfreetype6-2.14.3-h73754d4_0.conda @@ -1515,7 +1498,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libpciaccess-0.19-hb03c661_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libpng-1.6.58-h421ea60_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libpq-18.4-hd5a49e9_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.1-h0c1763c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.2-h0c1763c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_19.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-15.2.0-hdf11a46_19.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libtiff-4.7.1-h9d88235_1.conda @@ -1535,20 +1518,20 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.4.6-py313hf6604e3_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.4-h55fea9a_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openldap-2.6.13-hbde042b_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.2-h35e630c_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.3-h35e630c_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.47-haa7fec5_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pillow-12.2.0-py313h80991f8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pixman-0.46.4-h54a6638_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pthread-stubs-0.4-hb9d3cd8_1002.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pyside6-6.11.1-py313hcd51b16_1.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.13-h6add32d_100_cp313.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.14-h6add32d_100_cp313.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pyyaml-6.0.3-py313h3dea7bd_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/qhull-2020.2-h434a139_5.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/qt6-main-6.11.1-pl5321h16c4a6b_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/qt6-main-6.11.1-pl5321h16c4a6b_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.15-h6a952e8_1.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.17-h6a952e8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h366c992_103.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/tornado-6.5.6-py313h07c4f96_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/tornado-6.5.7-py313h07c4f96_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/wayland-1.25.0-hd6090a7_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-0.4.1-h4f16b4b_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-cursor-0.1.6-hb03c661_0.conda @@ -1556,7 +1539,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-keysyms-0.4.1-hb711507_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-renderutil-0.3.10-hb711507_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xcb-util-wm-0.4.2-hb711507_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.47-hb03c661_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.47-h280c20c_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libice-1.1.2-hb9d3cd8_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libsm-1.2.6-he73a12e_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libx11-1.8.13-he1eb515_0.conda @@ -1604,36 +1587,32 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-hc9c84f9_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/zipp-4.1.0-pyhcf101f3_0.conda - pypi: ./ - - pypi: ./submodules/drone-controllers - - pypi: ./submodules/drone-models - pypi: https://files.pythonhosted.org/packages/05/98/716a473cfb24750858ddd5d14e6527539dd206583a46408d08eeb2844a75/trimesh-4.12.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/0c/b9/c3df11997d29e69b3f8edae1e903bf44eaf4774ccf4c5b6ddcebde88931c/pytest_markdown_docs-0.9.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/15/c0/0be24758891ef825f2065cd5db8741aaddabe3e248ee6acc5e8a80f04005/uvloop-0.22.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/1d/69a0ba52fb546261e71a7209378ee6059950e9c088a2a18355e01509f474/jaxlib-0.10.1-cp313-cp313-manylinux_2_27_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/38/2e/21a3ede87f0bf82d6c7bcb90480d50a6490eb974c6ab20881188e440957c/simplejson-4.1.1-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl - pypi: https://files.pythonhosted.org/packages/3f/5b/7120e22f6e22ca77283f4a086ab2e59d107f00bfc952116db41a015385fe/casadi-3.7.2-cp313-none-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/2b/36e984399089c026a6499ac8f7401d38487cf0183839a4aa78140d373771/treescope-0.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/03/5b668e78eff52a459c707e442a3cbd3e0f8b74d08a4b92111a07159aff11/mujoco_mjx-3.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3f/efeb7c6801c46e11bd666a5180f0d615f74f72264212f74f39586c6fda9d/glfw-2.10.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.py39.py310.py311.py312.py313.py314-none-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5c/6e/5087e0347188f6970aba1ffbd0018754d23c3f3461e9f21785f2f27a02c2/jax-0.10.1-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5e/c6/82669e70cef67c803852285ba6f59d7e3d102983c0ab4be8269c14756677/tensorstore-0.1.84-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/7c/f0/21f81892e4ed10f4ec3ef2e7cf8635fb76e7c0907c55d0da66be50094760/farama_notifications-0.0.6-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8a/69/6a93d8600c339d7687a05857c7907bd4dd8cf88691a5ea106d7a50af90a1/optax-0.2.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/65/4bd2abfd4cb6e917b2626de5cbfc034dfc94b74dd95b8272d93f2ad66bed/flax-0.12.7-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/97/df/a1495de78c1da3e8e93978dd177b04d18aaa7361452e30a3467c41c3b19e/mujoco-3.9.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl @@ -1642,9 +1621,11 @@ environments: - pypi: https://files.pythonhosted.org/packages/c7/d1/63b5014a6184210292c66944f051e9fc95c0272fe5464d1b1a2de5de0104/orbax_checkpoint-0.12.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d5/0c/043d5e551459da400957a1395e0febbf771446ff34291afcbe3d8be2a279/fsspec-2026.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/de/e4/1ba6f44e491c4eece978685230dde56b14d51a0365bc1b774ddaa94d14cd/pyopengl-3.1.10-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e4/be/5b3cfe508bfab6761414ff944e3366eb13be4fd71efcd69450f89ba39f43/protobuf-7.35.1-cp310-abi3-manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/e9/73/fda6a25f3beeb5e49d74330b44092b9e5a547395ccd478d1103ddcbff1fc/gymnasium-1.3.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/eb/33/40cd74219417e78b97c47802037cf2d87b91973e18bb968a7da48a96ea44/ml_dtypes-0.5.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/f5/5f/f17563f28ff03c7b6799c50d01d5d856a1d55f2676f537ca8d28c7f627cd/scipy-1.17.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl + - pypi: https://files.pythonhosted.org/packages/fb/63/68f5d0ea81e167db5f59ddb94dc6f837667062113feff1c73fabf8907061/msgpack-1.2.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl osx-arm64: - conda: https://conda.anaconda.org/conda-forge/noarch/array-api-strict-2.5-pyhcf101f3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/ca-certificates-2026.5.20-hbd8a1cb_0.conda @@ -1673,7 +1654,8 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/bzip2-1.0.8-hd037594_9.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/contourpy-1.3.3-py313h2af2deb_4.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/fonttools-4.63.0-py313h65a2061_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/freetype-2.14.3-hce30654_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/freetype-2.14.3-hce30654_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/icu-78.3-hef89b57_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/kiwisolver-1.5.0-py313h2af2deb_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/lcms2-2.19.1-hdfa7624_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/lerc-4.1.0-h1eee2c3_0.conda @@ -1684,10 +1666,10 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libcblas-3.11.0-8_hb0561ab_openblas.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libcxx-22.1.7-h55c6f16_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libdeflate-1.25-hc11a715_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libexpat-2.8.1-hf6b4638_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libexpat-2.8.1-hf6b4638_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libffi-3.5.2-hcf2aa1b_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype-2.14.3-hce30654_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype6-2.14.3-hdfa99f5_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype-2.14.3-hce30654_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype6-2.14.3-hdfa99f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgcc-15.2.0-hcbb3090_19.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgfortran-15.2.0-h07b0088_19.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgfortran5-15.2.0-hdae7583_19.conda @@ -1697,7 +1679,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libmpdec-4.0.0-h84a0fba_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libopenblas-0.3.33-openmp_he657e61_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libpng-1.6.58-h132b30e_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.53.1-h1b79a29_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.53.2-h1ae2325_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libtiff-4.7.1-h4030677_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libwebp-base-1.6.0-h07db88b_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libxcb-1.17.0-hdb1d25a_0.conda @@ -1708,34 +1690,33 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ncurses-6.6-h1d4f5a5_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-2.4.6-py313hce9b930_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openjpeg-2.5.4-hd9e9057_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.6.2-hd24854e_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.6.3-hd24854e_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pillow-12.2.0-py313h45e5a15_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pthread-stubs-0.4-hd74edd7_1002.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.13.13-h20e6be0_100_cp313.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.13.14-h448ec07_100_cp313.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pyyaml-6.0.3-py313h65a2061_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/qhull-2020.2-h420ef59_5.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/readline-8.3-h46df422_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.15.15-h80928e0_1.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.15.17-h80928e0_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tk-8.6.13-h010d191_3.conda - - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tornado-6.5.6-py313h0997733_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tornado-6.5.7-py313h0997733_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxau-1.0.12-hc919400_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/xorg-libxdmcp-1.1.5-hc919400_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/yaml-0.2.5-h925e9cb_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zlib-ng-2.3.3-hed4e4f5_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/zstd-1.5.7-hbf9d68e_6.conda - pypi: ./ - - pypi: ./submodules/drone-controllers - - pypi: ./submodules/drone-models - pypi: https://files.pythonhosted.org/packages/05/98/716a473cfb24750858ddd5d14e6527539dd206583a46408d08eeb2844a75/trimesh-4.12.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/09/b7/087fcbfe2a0a0b44e236c9853d7fa7c539db6b8c60ab5702fffd73be5a7c/casadi-3.7.2-cp313-none-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/0c/b9/c3df11997d29e69b3f8edae1e903bf44eaf4774ccf4c5b6ddcebde88931c/pytest_markdown_docs-0.9.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/10/03/8aeeb7458d22546bf64b5250ca1daeb5ff757d900e8e4a7476c6f0db843e/protobuf-7.35.1-cp310-abi3-macosx_10_9_universal2.whl - pypi: https://files.pythonhosted.org/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/41/83/4f3c6ef9bed01f384036c2030b3901cf075bbc8eff6e4529e502f0283ab5/tensorstore-0.1.84-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/43/2b/36e984399089c026a6499ac8f7401d38487cf0183839a4aa78140d373771/treescope-0.1.10-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/4c/03/5b668e78eff52a459c707e442a3cbd3e0f8b74d08a4b92111a07159aff11/mujoco_mjx-3.9.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl @@ -1743,15 +1724,13 @@ environments: - pypi: https://files.pythonhosted.org/packages/7c/f0/21f81892e4ed10f4ec3ef2e7cf8635fb76e7c0907c55d0da66be50094760/farama_notifications-0.0.6-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/83/ee/93d06e358a4aa32280b00e722d3ea0a1f25fc3cc5778d80581c9cca2c10e/protobuf-7.35.0-cp310-abi3-macosx_10_9_universal2.whl + - pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/89/8c/182a2a593195bfd39842ea68ebc084e20c850806117213f5a299dfc513d9/uvloop-0.22.1-cp313-cp313-macosx_10_13_universal2.whl - pypi: https://files.pythonhosted.org/packages/8a/69/6a93d8600c339d7687a05857c7907bd4dd8cf88691a5ea106d7a50af90a1/optax-0.2.8-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8d/9b/d4b1e644385499c8346fa9b622a3f030dce14cd6ef8a1871c221a17a67e7/prometheus_client-0.25.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8e/65/4bd2abfd4cb6e917b2626de5cbfc034dfc94b74dd95b8272d93f2ad66bed/flax-0.12.7-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/92/dc/c385f38f2c2433333345a82926c6bfa5ecfff3ef787201614317b58dd8be/msgpack-1.1.2-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/9a/e5/54cb7c50ad5fdc1e0a86b7df4b135c2cbd5c4623605aa94466659098e8da/simplejson-4.1.1-cp313-cp313-macosx_11_0_arm64.whl - - pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b1/88/a29bca408a4c2db6c5bcf58a8b92c464660b7f846c559abd9110783574cb/mujoco-3.9.0-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl @@ -1759,6 +1738,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/bc/8a/340a1555ae33d7354dbca4faa54948d76d89a27ceef032c8c3bc661d003e/aiofiles-25.1.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c5/7b/bca5613a0c3b542420cf92bd5e5fb8ebd5435ce1011a091f66bb7693285e/humanize-4.15.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/c7/d1/63b5014a6184210292c66944f051e9fc95c0272fe5464d1b1a2de5de0104/orbax_checkpoint-0.12.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/c9/59/7e6b812629d2f919e586041bffc130e1af32079f71bb20699eed54ed6d92/msgpack-1.2.0-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/cf/76/3b637d4def229015a3035a7b44fac0dcf2536ae337540cdbffc651334d4e/jaxlib-0.10.1-cp313-cp313-macosx_11_0_arm64.whl - pypi: https://files.pythonhosted.org/packages/d5/0c/043d5e551459da400957a1395e0febbf771446ff34291afcbe3d8be2a279/fsspec-2026.4.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d9/a1/4008f14bbc616cfb1ac5b39ea485f9c63031c4634ab3f4cf72e7541f816a/ml_dtypes-0.5.4-cp313-cp313-macosx_10_13_universal2.whl @@ -1780,31 +1760,31 @@ packages: purls: [] size: 28948 timestamp: 1770939786096 -- conda: https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.16-hb03c661_0.conda - sha256: 64484cb7e7cf65d9c7188498d595125a6e0751604dd7fb483e1bb327eec0f738 - md5: 18d273b22e96c97c2017813f099faa82 +- conda: https://conda.anaconda.org/conda-forge/linux-64/alsa-lib-1.2.16.1-hb03c661_0.conda + sha256: cf93ca0f1f107e95a35969a4622684e08fcb8cf37f8cf4a1e9e424828386c921 + md5: 8904e09bda369377b3dd07e2ac828c5d depends: - __glibc >=2.17,<3.0.a0 - libgcc >=14 license: LGPL-2.1-or-later - license_family: GPL + license_family: LGPL purls: [] - size: 591046 - timestamp: 1780398678742 -- conda: https://conda.anaconda.org/conda-forge/linux-64/backports.zstd-1.5.0-py313h18e8e13_0.conda - sha256: 310e114a783b249517d1dd6e74b3f339af30e947bc93446ae4e4e9c86fff7478 - md5: 0de0c2c1f2677ea074bdda91de5a4c01 + size: 592377 + timestamp: 1781521980743 +- conda: https://conda.anaconda.org/conda-forge/linux-64/backports.zstd-1.6.0-py313h18e8e13_0.conda + sha256: d7aca2b34335035ef80eaaebff4e5a0ae524b6f3792a82a144d38c4a81f42916 + md5: fbc7d3707f09cae6648e6eb1b6203a5c depends: - python - - libgcc >=14 - __glibc >=2.17,<3.0.a0 + - libgcc >=14 - python_abi 3.13.* *_cp313 - zstd >=1.5.7,<1.6.0a0 license: BSD-3-Clause AND MIT AND EPL-2.0 purls: - - pkg:pypi/backports-zstd?source=hash-mapping - size: 242514 - timestamp: 1778594045042 + - pkg:pypi/backports-zstd?source=compressed-mapping + size: 242845 + timestamp: 1781450812380 - conda: https://conda.anaconda.org/conda-forge/linux-64/brotli-1.2.0-hed03a55_1.conda sha256: e511644d691f05eb12ebe1e971fd6dc3ae55a4df5c253b4e1788b789bdf2dfa6 md5: 8ccf913aaba749a5496c17629d859ed1 @@ -1973,6 +1953,7 @@ packages: - libuuid >=2.42.1,<3.0a0 - libzlib >=1.3.2,<2.0a0 license: MIT + license_family: MIT purls: [] size: 281880 timestamp: 1780450077431 @@ -2027,6 +2008,7 @@ packages: - libgcc >=14 - libstdcxx >=14 license: LGPL-2.0-or-later + license_family: LGPL purls: [] size: 100054 timestamp: 1780454302233 @@ -2046,6 +2028,7 @@ packages: - libstdcxx >=14 - libzlib >=1.3.2,<2.0a0 license: MIT + license_family: MIT purls: [] size: 2362258 timestamp: 1780450503234 @@ -2223,19 +2206,19 @@ packages: purls: [] size: 18778 timestamp: 1779859107964 -- conda: https://conda.anaconda.org/conda-forge/linux-64/libclang13-22.1.6-default_h746c552_1.conda - sha256: 4f91ada190f6e78efd9179fd995c9c7fe2f4bb00aef977a164b1cba8d49973bb - md5: bf306e7b1c8c2c204b28138a08666bbd +- conda: https://conda.anaconda.org/conda-forge/linux-64/libclang13-22.1.7-default_h746c552_1.conda + sha256: 5100d6571c361a3b4123007b71448a15901ad63ac948f3f02bbc7df4079fe4d1 + md5: f5d04d68e7fd19a24f1fe35a74bafabb depends: - __glibc >=2.17,<3.0.a0 - libgcc >=14 - - libllvm22 >=22.1.6,<22.2.0a0 + - libllvm22 >=22.1.7,<22.2.0a0 - libstdcxx >=14 license: Apache-2.0 WITH LLVM-exception license_family: Apache purls: [] - size: 12828360 - timestamp: 1779397396725 + size: 12818349 + timestamp: 1780522452233 - conda: https://conda.anaconda.org/conda-forge/linux-64/libcups-2.3.3-h7a8fb5f_6.conda sha256: 205c4f19550f3647832ec44e35e6d93c8c206782bdd620c1d7cf66237580ff9c md5: 49c553b47ff679a6a1e9fc80b9c5a2d4 @@ -2308,9 +2291,9 @@ packages: purls: [] size: 31718 timestamp: 1779728222280 -- conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_0.conda - sha256: 363018b25fdb5534c79783d912bd4b685a3547f4fc5996357ad548899b0ee8e7 - md5: 93764a5ca80616e9c10106cdaec92f74 +- conda: https://conda.anaconda.org/conda-forge/linux-64/libexpat-2.8.1-hecca717_1.conda + sha256: 16feffd9ddbbe5b718515d38ee376c685ba95491cd901244e24671d20b952a77 + md5: b24d3c612f71e7aa74158d92106318b2 depends: - __glibc >=2.17,<3.0.a0 - libgcc >=14 @@ -2319,8 +2302,8 @@ packages: license: MIT license_family: MIT purls: [] - size: 77294 - timestamp: 1779278686680 + size: 77856 + timestamp: 1781203599810 - conda: https://conda.anaconda.org/conda-forge/linux-64/libffi-3.5.2-h3435931_0.conda sha256: 31f19b6a88ce40ebc0d5a992c131f57d919f73c0b92cd1617a5bec83f6e961e6 md5: a360c33a5abe61c07959e449fa1453eb @@ -2642,17 +2625,17 @@ packages: purls: [] size: 2754709 timestamp: 1778786234149 -- conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.1-h0c1763c_0.conda - sha256: 54cdcd3214313b62c2a8ee277e6f42150d9b748264c1b70d958bf735e420ef8d - md5: 7dc38adcbf71e6b38748e919e16e0dce +- conda: https://conda.anaconda.org/conda-forge/linux-64/libsqlite-3.53.2-h0c1763c_0.conda + sha256: 1ab603b6ec93933e76027e1f23b21b22b858ba1b56f1e1695ef6fe5e80cb7358 + md5: 062b0ac602fb0adf250e3dfa86f221c4 depends: - __glibc >=2.17,<3.0.a0 - libgcc >=14 - libzlib >=1.3.2,<2.0a0 license: blessing purls: [] - size: 954962 - timestamp: 1777986471789 + size: 957849 + timestamp: 1780574429573 - conda: https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-15.2.0-h934c35e_19.conda sha256: dff1058c76ec6b8759e41cefa2508162d00e4a5e6721aa68ec3fd10094e702dc md5: 5794b3bdc38177caf969dabd3af08549 @@ -2963,7 +2946,7 @@ packages: license: BSD-3-Clause license_family: BSD purls: - - pkg:pypi/numpy?source=compressed-mapping + - pkg:pypi/numpy?source=hash-mapping size: 8759520 timestamp: 1779169200325 - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.4.6-py313hf6604e3_0.conda @@ -2983,7 +2966,7 @@ packages: license: BSD-3-Clause license_family: BSD purls: - - pkg:pypi/numpy?source=compressed-mapping + - pkg:pypi/numpy?source=hash-mapping size: 8864096 timestamp: 1779169199037 - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.4-h55fea9a_0.conda @@ -3016,9 +2999,9 @@ packages: purls: [] size: 786149 timestamp: 1775741359582 -- conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.2-h35e630c_0.conda - sha256: c0ef482280e38c71a08ad6d71448194b719630345b0c9c60744a2010e8a8e0cb - md5: da1b85b6a87e141f5140bb9924cecab0 +- conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.6.3-h35e630c_0.conda + sha256: d48f5c22b9897c01e4dff3680f1f57ceb02711ab9c62f74339b080419dfad34b + md5: 79dd2074b5cd5c5c6b2930514a11e22d depends: - __glibc >=2.17,<3.0.a0 - ca-certificates @@ -3026,8 +3009,8 @@ packages: license: Apache-2.0 license_family: Apache purls: [] - size: 3167099 - timestamp: 1775587756857 + size: 3159683 + timestamp: 1781069855778 - conda: https://conda.anaconda.org/conda-forge/linux-64/pcre2-10.47-haa7fec5_0.conda sha256: 5e6f7d161356fefd981948bea5139c5aa0436767751a6930cb1ca801ebb113ff md5: 7a3bff861a6583f1889021facefc08b1 @@ -3190,32 +3173,32 @@ packages: purls: [] size: 31608571 timestamp: 1772730708989 -- conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.13-h6add32d_100_cp313.conda +- conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.13.14-h6add32d_100_cp313.conda build_number: 100 - sha256: 7f77eb57648f545c1f58e10035d0d9d66b0a0efb7c4b58d3ed89ec7269afdde1 - md5: 05051be49267378d2fcd12931e319ac3 + sha256: f2146aff59ce4b571a8f1d1acf94f9bed6cc18ab5632d7dcc940fb48ecdeef99 + md5: 93762cd272814a142cf21d794f8fb0c1 depends: - __glibc >=2.17,<3.0.a0 - bzip2 >=1.0.8,<2.0a0 - ld_impl_linux-64 >=2.36.1 - - libexpat >=2.7.5,<3.0a0 + - libexpat >=2.8.1,<3.0a0 - libffi >=3.5.2,<3.6.0a0 - libgcc >=14 - - liblzma >=5.8.2,<6.0a0 + - liblzma >=5.8.3,<6.0a0 - libmpdec >=4.0.0,<5.0a0 - - libsqlite >=3.52.0,<4.0a0 - - libuuid >=2.42,<3.0a0 + - libsqlite >=3.53.2,<4.0a0 + - libuuid >=2.42.1,<3.0a0 - libzlib >=1.3.2,<2.0a0 - - ncurses >=6.5,<7.0a0 - - openssl >=3.5.6,<4.0a0 + - ncurses >=6.6,<7.0a0 + - openssl >=3.5.7,<4.0a0 - python_abi 3.13.* *_cp313 - readline >=8.3,<9.0a0 - tk >=8.6.13,<8.7.0a0 - tzdata license: Python-2.0 purls: [] - size: 37358322 - timestamp: 1775614712638 + size: 37398694 + timestamp: 1781258934574 python_site_packages_path: lib/python3.13/site-packages - conda: https://conda.anaconda.org/conda-forge/linux-64/pyyaml-6.0.3-py313h3dea7bd_1.conda sha256: ef7df29b38ef04ec67a8888a4aa039973eaa377e8c4b59a7be0a1c50cd7e4ac6 @@ -3243,9 +3226,9 @@ packages: purls: [] size: 552937 timestamp: 1720813982144 -- conda: https://conda.anaconda.org/conda-forge/linux-64/qt6-main-6.11.1-pl5321h16c4a6b_0.conda - sha256: 787d9a8eb7bb993e4a543901b8edade35c1c8e75d67cadb65c56a8f9c38119a5 - md5: cdae26862f9e4c674b8443fd267f2401 +- conda: https://conda.anaconda.org/conda-forge/linux-64/qt6-main-6.11.1-pl5321h16c4a6b_1.conda + sha256: aefbc43bde188ff4027d480da99c7fa9e8e6341e9762e065190239cb9b99bb1c + md5: 331d660aef48fec733a878dd1f8f4206 depends: - libxcb - xcb-util @@ -3256,66 +3239,66 @@ packages: - xcb-util-cursor - libgl-devel - libegl-devel + - libgcc >=14 - __glibc >=2.17,<3.0.a0 - libstdcxx >=14 - - libgcc >=14 - - libjpeg-turbo >=3.1.4.1,<4.0a0 - - alsa-lib >=1.2.15.3,<1.3.0a0 - - xcb-util-renderutil >=0.3.10,<0.4.0a0 - - xorg-libxcursor >=1.2.3,<2.0a0 - - double-conversion >=3.4.0,<3.5.0a0 - - dbus >=1.16.2,<2.0a0 - - libglib >=2.88.1,<3.0a0 - - libsqlite >=3.53.1,<4.0a0 + - xcb-util >=0.4.1,<0.5.0a0 - xorg-libx11 >=1.8.13,<2.0a0 - - krb5 >=1.22.2,<1.23.0a0 - - libdrm >=2.4.125,<2.5.0a0 - - xorg-libxext >=1.3.7,<2.0a0 - - libwebp-base >=1.6.0,<2.0a0 - - harfbuzz >=14.2.0 - - xorg-libice >=1.1.2,<2.0a0 - - xorg-libxdamage >=1.1.6,<2.0a0 + - pcre2 >=10.47,<10.48.0a0 - libbrotlicommon >=1.2.0,<1.3.0a0 - libbrotlienc >=1.2.0,<1.3.0a0 - libbrotlidec >=1.2.0,<1.3.0a0 + - fontconfig >=2.18.1,<3.0a0 + - fonts-conda-ecosystem + - xorg-libxxf86vm >=1.1.7,<2.0a0 - xorg-libxrandr >=1.5.5,<2.0a0 - - xcb-util-image >=0.4.0,<0.5.0a0 + - libsqlite >=3.53.2,<4.0a0 + - libpq >=18.4,<19.0a0 + - xorg-libice >=1.1.2,<2.0a0 + - libtiff >=4.7.1,<4.8.0a0 - libfreetype >=2.14.3 - libfreetype6 >=2.14.3 - - xorg-libxcomposite >=0.4.7,<1.0a0 - - openssl >=3.5.6,<4.0a0 - - xcb-util-cursor >=0.1.6,<0.2.0a0 + - wayland >=1.25.0,<2.0a0 + - libzlib >=1.3.2,<2.0a0 - libvulkan-loader >=1.4.341.0,<2.0a0 - - pcre2 >=10.47,<10.48.0a0 - - xorg-libxxf86vm >=1.1.7,<2.0a0 - - icu >=78.3,<79.0a0 + - xorg-libxext >=1.3.7,<2.0a0 + - xcb-util-keysyms >=0.4.1,<0.5.0a0 + - libpng >=1.6.58,<1.7.0a0 + - harfbuzz >=14.2.1 + - xcb-util-cursor >=0.1.6,<0.2.0a0 + - xorg-libxcursor >=1.2.3,<2.0a0 + - libcups >=2.3.3,<2.4.0a0 - libxcb >=1.17.0,<2.0a0 - - wayland >=1.25.0,<2.0a0 - - xorg-libsm >=1.2.6,<2.0a0 - - libtiff >=4.7.1,<4.8.0a0 - - libpq >=18.3,<19.0a0 - - libegl >=1.7.0,<2.0a0 + - libjpeg-turbo >=3.1.4.1,<4.0a0 + - libdrm >=2.4.127,<2.5.0a0 + - xorg-libxcomposite >=0.4.7,<1.0a0 + - xcb-util-image >=0.4.0,<0.5.0a0 - xcb-util-wm >=0.4.2,<0.5.0a0 - - xcb-util-keysyms >=0.4.1,<0.5.0a0 - - libxkbcommon >=1.13.1,<2.0a0 - zstd >=1.5.7,<1.6.0a0 - - fontconfig >=2.17.1,<3.0a0 - - fonts-conda-ecosystem + - krb5 >=1.22.2,<1.23.0a0 + - xcb-util-renderutil >=0.3.10,<0.4.0a0 + - icu >=78.3,<79.0a0 + - xorg-libxdamage >=1.1.6,<2.0a0 + - xorg-libsm >=1.2.6,<2.0a0 + - alsa-lib >=1.2.16,<1.3.0a0 + - openssl >=3.5.6,<4.0a0 + - libglib >=2.88.1,<3.0a0 + - libgl >=1.7.0,<2.0a0 + - libxkbcommon >=1.13.2,<2.0a0 + - libwebp-base >=1.6.0,<2.0a0 + - double-conversion >=3.4.0,<3.5.0a0 + - dbus >=1.16.2,<2.0a0 - xorg-libxtst >=1.2.5,<2.0a0 + - libegl >=1.7.0,<2.0a0 - libxml2 - libxml2-16 >=2.14.6 - - libzlib >=1.3.2,<2.0a0 - - libpng >=1.6.58,<1.7.0a0 - - libcups >=2.3.3,<2.4.0a0 - - libgl >=1.7.0,<2.0a0 - - xcb-util >=0.4.1,<0.5.0a0 constrains: - qt ==6.11.1 license: LGPL-3.0-only license_family: LGPL purls: [] - size: 60185269 - timestamp: 1778597122245 + size: 60185421 + timestamp: 1780593127053 - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.3-h853b02a_0.conda sha256: 12ffde5a6f958e285aa22c191ca01bbd3d6e710aa852e00618fa6ddc59149002 md5: d7d95fc8287ea7bf33e0e7116d2b95ec @@ -3328,10 +3311,10 @@ packages: purls: [] size: 345073 timestamp: 1765813471974 -- conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.15-h6a952e8_1.conda +- conda: https://conda.anaconda.org/conda-forge/linux-64/ruff-0.15.17-h6a952e8_0.conda noarch: python - sha256: 69254aead1c5f6c7e6d7ca195219b655fae4f9d0111ced58b6ceb6cb849cbcd1 - md5: e296d828d3b0cfec4e553ed59c52f17c + sha256: 71fa0408cedb61dac0e94603d01faf42304297ff0d8788c265d853c7712fb644 + md5: 53a10deada836b3a1309b2179f677a6f depends: - python - __glibc >=2.17,<3.0.a0 @@ -3341,9 +3324,9 @@ packages: license: MIT license_family: MIT purls: - - pkg:pypi/ruff?source=compressed-mapping - size: 9174319 - timestamp: 1780055663369 + - pkg:pypi/ruff?source=hash-mapping + size: 9333604 + timestamp: 1781208921649 - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h366c992_103.conda sha256: cafeec44494f842ffeca27e9c8b0c27ed714f93ac77ddadc6aaf726b5554ebac md5: cffd3bdd58090148f4cfcd831f4b26ab @@ -3358,9 +3341,9 @@ packages: purls: [] size: 3301196 timestamp: 1769460227866 -- conda: https://conda.anaconda.org/conda-forge/linux-64/tornado-6.5.6-py312h4c3975b_0.conda - sha256: 1a5f2eee536c5ea97dd9674b1b883d8d676ee346f10c8d4ae30626e20aa6d289 - md5: dcbe46475eff6fb9d1adad473c984f39 +- conda: https://conda.anaconda.org/conda-forge/linux-64/tornado-6.5.7-py312h4c3975b_0.conda + sha256: f54504d6eeef133ddc2b964b6a021f3faf085bb08bd70debc07f56d6b9b726f1 + md5: 55f526c3fb5302a1ce922612348442e1 depends: - __glibc >=2.17,<3.0.a0 - libgcc >=14 @@ -3370,11 +3353,11 @@ packages: license_family: Apache purls: - pkg:pypi/tornado?source=compressed-mapping - size: 860785 - timestamp: 1779915943143 -- conda: https://conda.anaconda.org/conda-forge/linux-64/tornado-6.5.6-py313h07c4f96_0.conda - sha256: 74c8c049b7b4cc5dfd48d99d12fc833e56fa2a24d8673ee2987c9c36585b7841 - md5: b3dd8f2ac88471c33ddd166ffa4740ad + size: 864705 + timestamp: 1781006801632 +- conda: https://conda.anaconda.org/conda-forge/linux-64/tornado-6.5.7-py313h07c4f96_0.conda + sha256: ed2f17b2bc92389187b47e7444aa7d6ab6c70c3f3ba6bf2ced508027e60e82cb + md5: 9665531f18d2bcbc946e3ba0cf53ea91 depends: - __glibc >=2.17,<3.0.a0 - libgcc >=14 @@ -3384,8 +3367,8 @@ packages: license_family: Apache purls: - pkg:pypi/tornado?source=compressed-mapping - size: 884822 - timestamp: 1779915941038 + size: 885824 + timestamp: 1781006802459 - conda: https://conda.anaconda.org/conda-forge/linux-64/unicodedata2-17.0.1-py312h4c3975b_0.conda sha256: 895bbfe9ee25c98c922799de901387d842d7c01cae45c346879865c6a907f229 md5: 0b6c506ec1f272b685240e70a29261b8 @@ -3499,18 +3482,18 @@ packages: purls: [] size: 51689 timestamp: 1718844051451 -- conda: https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.47-hb03c661_0.conda - sha256: 19c2bb14bec84b0e995b56b752369775c75f1589314b43733948bb5f471a6915 - md5: b56e0c8432b56decafae7e78c5f29ba5 +- conda: https://conda.anaconda.org/conda-forge/linux-64/xkeyboard-config-2.47-h280c20c_1.conda + sha256: 2bd7452f68c39bfff954385b062aca9389262369e318739af270d23af47580a5 + md5: bb1e548a92b0efa12c3e2385ae2d4529 depends: - - __glibc >=2.17,<3.0.a0 - libgcc >=14 + - __glibc >=2.17,<3.0.a0 - xorg-libx11 >=1.8.13,<2.0a0 license: MIT license_family: MIT purls: [] - size: 399291 - timestamp: 1772021302485 + size: 440702 + timestamp: 1781482698093 - conda: https://conda.anaconda.org/conda-forge/linux-64/xorg-libice-1.1.2-hb9d3cd8_0.conda sha256: c12396aabb21244c212e488bbdc4abcdef0b7404b15761d9329f5a4a39113c4b md5: fb901ff28063514abb6046c9ec2c4a45 @@ -3813,6 +3796,7 @@ packages: license_family: MIT purls: - pkg:pypi/charset-normalizer?source=hash-mapping + run_exports: {} size: 58872 timestamp: 1775127203018 - conda: https://conda.anaconda.org/conda-forge/noarch/click-8.2.1-pyh707e725_0.conda @@ -4481,9 +4465,9 @@ packages: purls: [] size: 8325 timestamp: 1764092507920 -- conda: https://conda.anaconda.org/conda-forge/osx-arm64/backports.zstd-1.5.0-py313h7208f8c_0.conda - sha256: 3ff16bb31de2cd6699804388337a6efa8a33c12850b5387b13ef38460ca605a3 - md5: 27b54f7bbc635f824806978d9d201573 +- conda: https://conda.anaconda.org/conda-forge/osx-arm64/backports.zstd-1.6.0-py313h7208f8c_0.conda + sha256: 4d39bf744249f60212a728369dbc6cd6ec4d5aef6668a14321f747d7eb4bac2d + md5: 6ab3d07883ad437c12a8f5fd90c1df5b depends: - python - __osx >=11.0 @@ -4492,8 +4476,8 @@ packages: license: BSD-3-Clause AND MIT AND EPL-2.0 purls: - pkg:pypi/backports-zstd?source=hash-mapping - size: 243876 - timestamp: 1778594074850 + size: 243873 + timestamp: 1781450811773 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/brotli-1.2.0-h7d5ae5b_1.conda sha256: 422ac5c91f8ef07017c594d9135b7ae068157393d2a119b1908c7e350938579d md5: 48ece20aa479be6ac9a284772827d00c @@ -4578,16 +4562,26 @@ packages: - pkg:pypi/fonttools?source=hash-mapping size: 2983026 timestamp: 1778770717031 -- conda: https://conda.anaconda.org/conda-forge/osx-arm64/freetype-2.14.3-hce30654_0.conda - sha256: 5952bd9db12207a18a112e8924aa2ce8c2f9d57b62584d58a97d2f6afe1ea324 - md5: 6dcc75ba2e04c555e881b72793d3282f +- conda: https://conda.anaconda.org/conda-forge/osx-arm64/freetype-2.14.3-hce30654_1.conda + sha256: 96b33f1e2a32c602b167f43719e3acf89ec742b4a1e25e99ffd0e6f99b38d277 + md5: 7bd06ab4ed807154c2d9031eb5ebf025 depends: - - libfreetype 2.14.3 hce30654_0 - - libfreetype6 2.14.3 hdfa99f5_0 + - libfreetype 2.14.3 hce30654_1 + - libfreetype6 2.14.3 hdfa99f5_1 license: GPL-2.0-only OR FTL purls: [] - size: 173313 - timestamp: 1774298702053 + size: 173518 + timestamp: 1780933616544 +- conda: https://conda.anaconda.org/conda-forge/osx-arm64/icu-78.3-hef89b57_0.conda + sha256: 3a7907a17e9937d3a46dfd41cffaf815abad59a569440d1e25177c15fd0684e5 + md5: f1182c91c0de31a7abd40cedf6a5ebef + depends: + - __osx >=11.0 + license: MIT + license_family: MIT + purls: [] + size: 12361647 + timestamp: 1773822915649 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/kiwisolver-1.5.0-py313h2af2deb_0.conda sha256: b0ac975a7eb40638b1405c8092835c47222ce758eb26114afee50a8d1ce98569 md5: bd1e04d017f340e42431706402db8b02 @@ -4711,9 +4705,9 @@ packages: purls: [] size: 55420 timestamp: 1761980066242 -- conda: https://conda.anaconda.org/conda-forge/osx-arm64/libexpat-2.8.1-hf6b4638_0.conda - sha256: 3133fb6bfa871288b92c8b8752696686a841bf4ffe035aa3038033c9e15b738e - md5: ef22e9ab1dc7c2f334252f565f90b3b8 +- conda: https://conda.anaconda.org/conda-forge/osx-arm64/libexpat-2.8.1-hf6b4638_1.conda + sha256: 5af74261101e3c777399c6294b2b5d290e508153268eb2e9ff99c4d69834612f + md5: a915151d5d3c5bf039f5ccc8402a436f depends: - __osx >=11.0 constrains: @@ -4721,8 +4715,8 @@ packages: license: MIT license_family: MIT purls: [] - size: 69110 - timestamp: 1779278728511 + size: 69362 + timestamp: 1781203631990 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libffi-3.5.2-hcf2aa1b_0.conda sha256: 6686a26466a527585e6a75cc2a242bf4a3d97d6d6c86424a441677917f28bec7 md5: 43c04d9cb46ef176bb2a4c77e324d599 @@ -4733,28 +4727,28 @@ packages: purls: [] size: 40979 timestamp: 1769456747661 -- conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype-2.14.3-hce30654_0.conda - sha256: a047a2f238362a37d484f9620e8cba29f513a933cd9eb68571ad4b270d6f8f3e - md5: f73b109d49568d5d1dda43bb147ae37f +- conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype-2.14.3-hce30654_1.conda + sha256: d5637b01941c0fc8f5cbb1f170c238f4ee153b3c1708b9d50f4f1305438ff051 + md5: 0582e67cd14cfed773be2f3b1aba08e0 depends: - libfreetype6 >=2.14.3 license: GPL-2.0-only OR FTL purls: [] - size: 8091 - timestamp: 1774298691258 -- conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype6-2.14.3-hdfa99f5_0.conda - sha256: ff764608e1f2839e95e2cf9b243681475f8778c36af7a42b3f78f476fdbb1dd3 - md5: e98ba7b5f09a5f450eca083d5a1c4649 + size: 8365 + timestamp: 1780933612390 +- conda: https://conda.anaconda.org/conda-forge/osx-arm64/libfreetype6-2.14.3-hdfa99f5_1.conda + sha256: abbfffd8a8c776bb8b59a10c8247fc3aa6b17ba0051e9f6d199dca38479f214f + md5: a0bb0678f67c464938d3693fa96f6884 depends: - __osx >=11.0 - - libpng >=1.6.55,<1.7.0a0 + - libpng >=1.6.58,<1.7.0a0 - libzlib >=1.3.2,<2.0a0 constrains: - freetype >=2.14.3 license: GPL-2.0-only OR FTL purls: [] - size: 338085 - timestamp: 1774298689297 + size: 338442 + timestamp: 1780933611662 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libgcc-15.2.0-hcbb3090_19.conda sha256: 06644fa4d34d57c9e48f4d84b1256f9e5f654fdb37f43acc8a58a396952d42b7 md5: 644058123986582db33aebd4ae2ca184 @@ -4864,16 +4858,17 @@ packages: purls: [] size: 289546 timestamp: 1776315246750 -- conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.53.1-h1b79a29_0.conda - sha256: 49daec7c83e70d4efc17b813547824bc2bcf2f7256d84061d24fbfe537da9f74 - md5: 6681822ea9d362953206352371b6a904 +- conda: https://conda.anaconda.org/conda-forge/osx-arm64/libsqlite-3.53.2-h1ae2325_0.conda + sha256: 862463917e8ef5ac3ebdaf8f19914634b457609cc27ba678b7197124cefeb1f7 + md5: 1ebde5c677f00765233a17e278571177 depends: - __osx >=11.0 + - icu >=78.3,<79.0a0 - libzlib >=1.3.2,<2.0a0 license: blessing purls: [] - size: 920047 - timestamp: 1777987051643 + size: 927724 + timestamp: 1780575223548 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/libtiff-4.7.1-h4030677_1.conda sha256: e9248077b3fa63db94caca42c8dbc6949c6f32f94d1cafad127f9005d9b1507f md5: e2a72ab2fa54ecb6abab2b26cde93500 @@ -4937,6 +4932,7 @@ packages: - openmp 22.1.7|22.1.7.* - intel-openmp <0.0a0 license: Apache-2.0 WITH LLVM-exception + license_family: APACHE purls: [] size: 285162 timestamp: 1780455637760 @@ -5040,17 +5036,17 @@ packages: purls: [] size: 319697 timestamp: 1772625397692 -- conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.6.2-hd24854e_0.conda - sha256: c91bf510c130a1ea1b6ff023e28bac0ccaef869446acd805e2016f69ebdc49ea - md5: 25dcccd4f80f1638428613e0d7c9b4e1 +- conda: https://conda.anaconda.org/conda-forge/osx-arm64/openssl-3.6.3-hd24854e_0.conda + sha256: b3e3ca895c336d4eb91c5d2f244a312bdb59a0de8cfa0cc4c179225ab2f6bbfb + md5: 8187a86242741725bfa74785fe812979 depends: - __osx >=11.0 - ca-certificates license: Apache-2.0 license_family: Apache purls: [] - size: 3106008 - timestamp: 1775587972483 + size: 3102584 + timestamp: 1781069820667 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pillow-12.2.0-py313h45e5a15_0.conda sha256: 90333643a7868b10724999633bb393d005bc5f539d05666f80c41fb67e5f0f3f md5: 6186601fd72a394a6f7c7b7096f6a063 @@ -5084,29 +5080,29 @@ packages: purls: [] size: 8381 timestamp: 1726802424786 -- conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.13.13-h20e6be0_100_cp313.conda +- conda: https://conda.anaconda.org/conda-forge/osx-arm64/python-3.13.14-h448ec07_100_cp313.conda build_number: 100 - sha256: d0fffc5fde21d1ae350da545dfb9e115a8c53bed8a9c5761f9efd4a5581853c1 - md5: 9991a930e81d3873eba7a299ba783ec4 + sha256: c89eedab6b293fae654d75483d8f3e5eb3ff9ce2478134d902676c1dd20c7dfd + md5: e556c07deaa168043f8430bb046092e2 depends: - __osx >=11.0 - bzip2 >=1.0.8,<2.0a0 - - libexpat >=2.7.5,<3.0a0 + - libexpat >=2.8.1,<3.0a0 - libffi >=3.5.2,<3.6.0a0 - - liblzma >=5.8.2,<6.0a0 + - liblzma >=5.8.3,<6.0a0 - libmpdec >=4.0.0,<5.0a0 - - libsqlite >=3.52.0,<4.0a0 + - libsqlite >=3.53.2,<4.0a0 - libzlib >=1.3.2,<2.0a0 - - ncurses >=6.5,<7.0a0 - - openssl >=3.5.6,<4.0a0 + - ncurses >=6.6,<7.0a0 + - openssl >=3.5.7,<4.0a0 - python_abi 3.13.* *_cp313 - readline >=8.3,<9.0a0 - tk >=8.6.13,<8.7.0a0 - tzdata license: Python-2.0 purls: [] - size: 12966447 - timestamp: 1775615694085 + size: 17017633 + timestamp: 1781257915644 python_site_packages_path: lib/python3.13/site-packages - conda: https://conda.anaconda.org/conda-forge/osx-arm64/pyyaml-6.0.3-py313h65a2061_1.conda sha256: 950725516f67c9691d81bb8dde8419581c5332c5da3da10c9ba8cbb1698b825d @@ -5144,10 +5140,10 @@ packages: purls: [] size: 313930 timestamp: 1765813902568 -- conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.15.15-h80928e0_1.conda +- conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.15.17-h80928e0_0.conda noarch: python - sha256: 770d6f74f247b02f4cf8bc6f7066bac178746fbfabea12f2675aea20c50ba9c6 - md5: cbe46e26504a93010089bc3d3ed636aa + sha256: 56035432a678aec0aaa2195b282fd3a6649e4c9909641cad3962c387c3bf17f9 + md5: c39fd9e9b62acbfe880a8f07f4ff9b77 depends: - python - __osx >=11.0 @@ -5157,8 +5153,8 @@ packages: license_family: MIT purls: - pkg:pypi/ruff?source=compressed-mapping - size: 8424737 - timestamp: 1780055998652 + size: 8596345 + timestamp: 1781212277141 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/tk-8.6.13-h010d191_3.conda sha256: 799cab4b6cde62f91f750149995d149bc9db525ec12595e8a1d91b9317f038b3 md5: a9d86bc62f39b94c4661716624eb21b0 @@ -5170,9 +5166,9 @@ packages: purls: [] size: 3127137 timestamp: 1769460817696 -- conda: https://conda.anaconda.org/conda-forge/osx-arm64/tornado-6.5.6-py313h0997733_0.conda - sha256: b181f2e76169dc7c3df9e958e1cd55195974715b8180ef251c6b26ea74cbb442 - md5: a78d9ca10694bb78efb4026d57906882 +- conda: https://conda.anaconda.org/conda-forge/osx-arm64/tornado-6.5.7-py313h0997733_0.conda + sha256: 02c7b05be4d74da0d7329f92445646e3489634f0c3e59b26efefd846662ac20a + md5: dbc95e65c936251c3d32111c867d84e1 depends: - __osx >=11.0 - python >=3.13,<3.14.0a0 @@ -5182,8 +5178,8 @@ packages: license_family: Apache purls: - pkg:pypi/tornado?source=compressed-mapping - size: 884030 - timestamp: 1779916438404 + size: 889689 + timestamp: 1781007967544 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/watchdog-6.0.0-py313h6688731_3.conda sha256: 306ef968c4d5fef96430f6169c819366baa36c7c3e701933282fca43cba67548 md5: 90e28468c1fb2ade60be2535cbe25e5d @@ -5265,30 +5261,14 @@ packages: - flax - ml-collections - casadi - - drone-models==0.1.0 - - drone-controllers==0.2.0 + - array-api-compat + - array-api-extra - jax[cuda12] ; extra == 'gpu' - fire ; extra == 'benchmark' - matplotlib ; extra == 'benchmark' - pandas ; extra == 'benchmark' + - pyinstrument ; extra == 'benchmark' requires_python: '>=3.11,<3.14' -- pypi: ./submodules/drone-controllers - name: drone-controllers - requires_dist: - - numpy>=2.0.0 - - scipy>=1.17.0 - - array-api-compat - - array-api-extra -- pypi: ./submodules/drone-models - name: drone-models - requires_dist: - - numpy>=2.0.0 - - scipy>=1.17.0 - - casadi>=3.7.0 - - array-api-compat - - array-api-extra - - matplotlib ; extra == 'sysid' - - jax>=0.7 ; extra == 'sysid' - pypi: https://files.pythonhosted.org/packages/01/8e/1e35281b8ab6d5d72ebe9911edcdffa3f36b04ed9d51dec6dd140396e220/scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl name: scipy version: 1.17.1 @@ -5420,6 +5400,11 @@ packages: - virtualenv>=20.17 ; python_full_version >= '3.10' and python_full_version < '3.14' and extra == 'virtualenv' - virtualenv>=20.31 ; python_full_version >= '3.14' and extra == 'virtualenv' requires_python: '>=3.10' +- pypi: https://files.pythonhosted.org/packages/10/03/8aeeb7458d22546bf64b5250ca1daeb5ff757d900e8e4a7476c6f0db843e/protobuf-7.35.1-cp310-abi3-macosx_10_9_universal2.whl + name: protobuf + version: 7.35.1 + sha256: 24f857477359a85c0c235261b8ba905fd51b2562f4a64ca1df5473f29850cbf6 + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/11/8c/c9138d881c79aa0ea9ed83cbd58d5ca75624378b38cee225dcf5c42cc91f/griffelib-2.0.2-py3-none-any.whl name: griffelib version: 2.0.2 @@ -5485,6 +5470,22 @@ packages: - numpy>=2.0 - ml-dtypes>=0.5.0 requires_python: '>=3.11' +- pypi: https://files.pythonhosted.org/packages/20/2c/0622f20ff02b2ef32558733443805dc82fd4c275be01b2d19d14676f3a1b/cryptography-49.0.0-cp311-abi3-manylinux_2_28_x86_64.whl + name: cryptography + version: 49.0.0 + sha256: 2afe9051da7ae7bd5905da5a949280c7d2bb75682e188f650a9d0f2756b834c6 + requires_dist: + - cffi>=2.0.0 ; platform_python_implementation != 'PyPy' + - typing-extensions>=4.13.2 ; python_full_version < '3.11' + - bcrypt>=3.1.5 ; extra == 'ssh' + requires_python: '!=3.9.0,>=3.9,!=3.9.1' +- pypi: https://files.pythonhosted.org/packages/22/6a/3aa1055b4a5dc3195e79687bbe4fb2188e400c44c181b5843de81fee7553/array_api_extra-0.11.0-py3-none-any.whl + name: array-api-extra + version: 0.11.0 + sha256: 84e7349176ec0b2e03000f1a6f9c88556e29ec959e0de5ef21dc5925e4b44a05 + requires_dist: + - array-api-compat>=1.14.0,<2 + requires_python: '>=3.11' - pypi: https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl name: opt-einsum version: 3.4.0 @@ -5523,16 +5524,6 @@ packages: - mkdocs>=1.1 - pymdown-extensions>=9.2 requires_python: '>=3.7' -- pypi: https://files.pythonhosted.org/packages/32/28/79f0f8de97cce916d5ae88a7bee1ad724855e83e6019c0b4d5b3fabc80f3/mkdocstrings_python-2.0.3-py3-none-any.whl - name: mkdocstrings-python - version: 2.0.3 - sha256: 0b83513478bdfd803ff05aa43e9b1fca9dd22bcd9471f09ca6257f009bc5ee12 - requires_dist: - - mkdocstrings>=0.30 - - mkdocs-autorefs>=1.4 - - griffelib>=2.0 - - typing-extensions>=4.0 ; python_full_version < '3.11' - requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/32/91/30151a39f7570f448ed84529390628a651d7f27c87d73c9b887f8189695e/docutils-0.23-py3-none-any.whl name: docutils version: '0.23' @@ -5569,6 +5560,11 @@ packages: - numpy>=2.0 - ml-dtypes>=0.5.0 requires_python: '>=3.11' +- pypi: https://files.pythonhosted.org/packages/38/ed/b7728573156d70b6b094233b0f38d876fc37340826cf852347ec2c7ca8ca/msgpack-1.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl + name: msgpack + version: 1.2.0 + sha256: a0d94420d9d52c56568159a69200af7e45eadb29615fa9d09fada140de1c38c7 + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/3a/13/547360d81e6d88d58492968ffda9f9542854f11310ee556fef14260cc886/zipp-4.1.0-py3-none-any.whl name: zipp version: 4.1.0 @@ -5626,6 +5622,13 @@ packages: - pylint>=2.6.0 ; extra == 'dev' - pyink ; extra == 'dev' requires_python: '>=3.9' +- pypi: https://files.pythonhosted.org/packages/3e/85/1c12e849e4d50624e75496378a3fb168389f768d3ec7cb694fba873ff9a8/nvidia_nvshmem_cu12-3.7.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl + name: nvidia-nvshmem-cu12 + version: 3.7.0 + sha256: ca643cb87a214c0f7ad8396def747adcaa0c8dfb0cb7e5012338ac3b0d76404b + requires_dist: + - nvidia-cuda-cccl-cu12 + requires_python: '>=3' - pypi: https://files.pythonhosted.org/packages/3f/51/d4db610ef29373b879047326cbf6fa98b6c1969d6f6dc423279de2b1be2c/requests_toolbelt-1.0.0-py2.py3-none-any.whl name: requests-toolbelt version: 1.0.0 @@ -5740,13 +5743,6 @@ packages: version: 12.9.86 sha256: e3f1171dbdc83c5932a45f0f4c99180a70de9bd2718c1ab77d14104f6d7147f9 requires_python: '>=3' -- pypi: https://files.pythonhosted.org/packages/46/f7/9e14be985fd77ae26fee9136c9735e8987772e0ecf5f1f4e6e2b84cadc46/array_api_extra-0.10.1-py3-none-any.whl - name: array-api-extra - version: 0.10.1 - sha256: 9c2003079ccd2a0c92b1cf797b5867b9d7ea9428e75f70c7f78c1c0842d54368 - requires_dist: - - array-api-compat>=1.13.0,<2 - requires_python: '>=3.11' - pypi: https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl name: imageio version: 2.37.3 @@ -5833,6 +5829,11 @@ packages: - mkdocs>=1.4.1,<=1.6.1 - properdocs>=1.6.5 requires_python: '>=3.9' +- pypi: https://files.pythonhosted.org/packages/50/32/e7ffa9c324ae260e5dbb4af2cd557bf7a8d155c8ac7b79a785fe1796fb92/nvidia_nccl_cu12-2.30.7-py3-none-manylinux_2_18_x86_64.whl + name: nvidia-nccl-cu12 + version: 2.30.7 + sha256: 8ce1b8213f61f2bfac132e6df890af6450b77cbd140c6ce4e98cb0c2d8e678c9 + requires_python: '>=3' - pypi: https://files.pythonhosted.org/packages/5a/3d/589663aeeacd59bb2f3e8596bfd3e81cf0fb18d70bb433199041f469771b/etils-1.14.0-py3-none-any.whl name: etils version: 1.14.0 @@ -5941,11 +5942,6 @@ packages: - kubernetes ; extra == 'k8s' - xprof ; extra == 'xprof' requires_python: '>=3.11' -- pypi: https://files.pythonhosted.org/packages/5d/ba/459f18c16f2b3fc1a1ca871f72f07d70c07bf768ad0a507a698b8052ac58/msgpack-1.1.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - name: msgpack - version: 1.1.2 - sha256: fac4be746328f90caa3cd4bc67e6fe36ca2bf61d5c6eb6d895b6527e3f05071e - requires_python: '>=3.9' - pypi: https://files.pythonhosted.org/packages/5e/c6/82669e70cef67c803852285ba6f59d7e3d102983c0ab4be8269c14756677/tensorstore-0.1.84-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl name: tensorstore version: 0.1.84 @@ -5954,6 +5950,16 @@ packages: - numpy>=1.22.0 - ml-dtypes>=0.5.0 requires_python: '>=3.11' +- pypi: https://files.pythonhosted.org/packages/5e/e3/00ec594aef5f55522e6d373bc2ac53e53a8f5e9ae32f2d6854b0de4270f3/mkdocstrings_python-2.0.4-py3-none-any.whl + name: mkdocstrings-python + version: 2.0.4 + sha256: fd87c173e1e719a85997b6d4f852cdc55f36710e0ed08da3a7bd9abe79c9db00 + requires_dist: + - mkdocstrings>=0.30 + - mkdocs-autorefs>=1.4 + - griffelib>=2.0 + - typing-extensions>=4.0 ; python_full_version < '3.11' + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/5f/6f/e62b4dfc7ad6518e7eff2516f680d02a0f6eb62c0c212e152ca708a0085e/uvloop-0.22.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl name: uvloop version: 0.22.1 @@ -5971,11 +5977,6 @@ packages: - sphinxcontrib-asyncio~=0.3.0 ; extra == 'docs' - sphinx-rtd-theme~=0.5.2 ; extra == 'docs' requires_python: '>=3.8.1' -- pypi: https://files.pythonhosted.org/packages/65/92/a5100f7185a800a5d29f8d14041f61475b9de465ffcc0f3b9fba606e4505/msgpack-1.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl - name: msgpack - version: 1.1.2 - sha256: 372839311ccf6bdaf39b00b61288e0557916c3729529b301c52c2d88842add42 - requires_python: '>=3.9' - pypi: https://files.pythonhosted.org/packages/65/b6/09b01cdbc15224e2850365192d17b7bdebb8bdbd8780ed221fcdf0d9a515/pandas-3.0.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl name: pandas version: 3.0.3 @@ -6066,11 +6067,6 @@ packages: - xlsxwriter>=3.2.0 ; extra == 'all' - zstandard>=0.23.0 ; extra == 'all' requires_python: '>=3.11' -- pypi: https://files.pythonhosted.org/packages/6b/c3/0e45ff4dce8401f6ea7c25d80d75738813a47f5ae2691e2478f2fd1e5e93/nvidia_nccl_cu12-2.30.4-py3-none-manylinux_2_18_x86_64.whl - name: nvidia-nccl-cu12 - version: 2.30.4 - sha256: 040974b261edec4b8b793e59e92ab7176fe4ab4bc61b800f9f3bfaeec2d436f3 - requires_python: '>=3' - pypi: https://files.pythonhosted.org/packages/6e/94/be70f8ee9c45f2f62b39a1f0e9303bc20e138a8f3b8e50ffd89498e177e1/mkdocstrings-1.0.4-py3-none-any.whl name: mkdocstrings version: 1.0.4 @@ -6086,22 +6082,10 @@ packages: - mkdocstrings-python-legacy>=0.2.1 ; extra == 'python-legacy' - mkdocstrings-python>=1.16.2 ; extra == 'python' requires_python: '>=3.10' -- pypi: https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl - name: protobuf - version: 7.35.0 - sha256: 6c0f98f10c8a05ea30f8993dfef2de093d27b490fdae78bb60c8343795d55011 - requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/7c/f0/21f81892e4ed10f4ec3ef2e7cf8635fb76e7c0907c55d0da66be50094760/farama_notifications-0.0.6-py3-none-any.whl name: farama-notifications version: 0.0.6 sha256: f84839188efa1ce5bb361c2a84881b2dc2c0d0d7fb661ff00421820170930935 -- pypi: https://files.pythonhosted.org/packages/7d/9d/1a383211b0967e702b9e84643986fb31bf35ca07bddc19e0cf139fd3291d/nvidia_cudnn_cu12-9.23.0.39-py3-none-manylinux_2_27_x86_64.whl - name: nvidia-cudnn-cu12 - version: 9.23.0.39 - sha256: 89d53e2a2b0614278afbeda67ac89594bdd74f9f283f22f2d34409d55859846f - requires_dist: - - nvidia-cublas-cu12 - requires_python: '>=3' - pypi: https://files.pythonhosted.org/packages/7f/66/b15ce62552d84bbfcec9a4873ab79d993a1dd4edb922cbfccae192bd5b5f/jaraco.classes-3.4.0-py3-none-any.whl name: jaraco-classes version: 3.4.0 @@ -6204,16 +6188,16 @@ packages: - markdown-it-py>=2.2.0 - pygments>=2.13.0,<3.0.0 requires_python: '>=3.9.0' -- pypi: https://files.pythonhosted.org/packages/83/ee/93d06e358a4aa32280b00e722d3ea0a1f25fc3cc5778d80581c9cca2c10e/protobuf-7.35.0-cp310-abi3-macosx_10_9_universal2.whl - name: protobuf - version: 7.35.0 - sha256: 66be6c513931c794fa92c080ffee41671390da3d79da219cf9c0c0907f035dda - requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/85/30/d162e99746a2fb1d98bb0ef23af3e201b156cf09f7de867c7390c8fe1c06/nh3-0.3.5-cp38-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl name: nh3 version: 0.3.5 sha256: 3bb854485c9b33e5bb143ff3e49e577073bc6bc320f0ff8fc316dd89c0d3c101 requires_python: '>=3.8' +- pypi: https://files.pythonhosted.org/packages/86/16/1a8fd2b19544b84575cf84ef7aa3ad4c173b756d5f087c91f85d1b295777/array_api_compat-1.15.0-py3-none-any.whl + name: array-api-compat + version: 1.15.0 + sha256: 7b1b9c53269061403fd5f45a8de349f16e7887653328bfa0c5f2d45299ff0a8e + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl name: cloudpickle version: 3.1.2 @@ -6360,20 +6344,30 @@ packages: - pre-commit>=3.8.0 ; extra == 'dev' - scikit-build-core[pyproject]>=0.11.0 ; extra == 'dev' requires_python: '>=3.11' -- pypi: https://files.pythonhosted.org/packages/92/dc/c385f38f2c2433333345a82926c6bfa5ecfff3ef787201614317b58dd8be/msgpack-1.1.2-cp313-cp313-macosx_11_0_arm64.whl - name: msgpack - version: 1.1.2 - sha256: 42eefe2c3e2af97ed470eec850facbe1b5ad1d6eacdbadc42ec98e7dcf68b4b7 - requires_python: '>=3.9' -- pypi: https://files.pythonhosted.org/packages/95/38/0d29a6fd7d0d1373f0c0c88a04ba20e359b257753ac497564cd660fc1d55/cryptography-48.0.0-cp311-abi3-manylinux_2_28_x86_64.whl - name: cryptography - version: 48.0.0 - sha256: a0e692c683f4df67815a2d258b324e66f4738bd7a96a218c826dce4f4bd05d8f +- pypi: https://files.pythonhosted.org/packages/91/4d/2ca3ca9906ce6e05070f431c54d54ccbaf57a980cfa58032d35b0b0ac1f8/pyinstrument-5.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl + name: pyinstrument + version: 5.1.2 + sha256: 12af1e83795b6c640d657d339014dd1ff718b182dec736d7d1f1d8a97534eb53 requires_dist: - - cffi>=2.0.0 ; platform_python_implementation != 'PyPy' - - typing-extensions>=4.13.2 ; python_full_version < '3.11' - - bcrypt>=3.1.5 ; extra == 'ssh' - requires_python: '!=3.9.0,>=3.9,!=3.9.1' + - pytest ; extra == 'test' + - flaky ; extra == 'test' + - trio ; extra == 'test' + - cffi>=1.17.0 ; extra == 'test' + - greenlet>=3 ; extra == 'test' + - pytest-asyncio==0.23.8 ; extra == 'test' + - ipython ; extra == 'test' + - click ; extra == 'bin' + - nox ; extra == 'bin' + - sphinx==7.4.7 ; extra == 'docs' + - myst-parser==3.0.1 ; extra == 'docs' + - furo==2024.7.18 ; extra == 'docs' + - sphinxcontrib-programoutput==0.17 ; extra == 'docs' + - sphinx-autobuild==2024.4.16 ; extra == 'docs' + - numpy ; extra == 'examples' + - django ; extra == 'examples' + - litestar ; extra == 'examples' + - typing-extensions ; extra == 'types' + requires_python: '>=3.8' - pypi: https://files.pythonhosted.org/packages/95/f4/61e6996dd20481ee834f57a8e9dca28b1869366a135e0d42e2aa8493bdd4/nvidia_cufft_cu12-11.4.1.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl name: nvidia-cufft-cu12 version: 11.4.1.4 @@ -6401,6 +6395,16 @@ packages: - pytest-enabler>=3.4 ; extra == 'enabler' - pytest-mypy>=1.0.1 ; platform_python_implementation != 'PyPy' and extra == 'type' requires_python: '>=3.10' +- pypi: https://files.pythonhosted.org/packages/97/1b/295bf2fa3e740131778065e5ffa2c481f0e7210182d408e9a2c244ff5b0c/readme_renderer-45.0-py3-none-any.whl + name: readme-renderer + version: '45.0' + sha256: 3385ed220117104a2bceb4a9dac8c5fdf6d1f96890d7ea2a9c7174fd5c84091f + requires_dist: + - nh3>=0.2.14 + - docutils>=0.21.2 + - pygments>=2.5.1 + - comrak>=0.0.11 ; extra == 'md' + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/97/df/a1495de78c1da3e8e93978dd177b04d18aaa7361452e30a3467c41c3b19e/mujoco-3.9.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl name: mujoco version: 3.9.0 @@ -6436,40 +6440,13 @@ packages: version: 4.1.1 sha256: 249e2e220aa6d9b9d936bde84eb7bf79d5b6c5a8273c6e411f8b1635a9073f2d requires_python: '>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*' -- pypi: https://files.pythonhosted.org/packages/9e/da/36fa8307cc40889307fed415d70b67d35ec330ffce889a9c03cf8f616cfa/nvidia_nvshmem_cu12-3.6.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl - name: nvidia-nvshmem-cu12 - version: 3.6.5 - sha256: f86db35f1ced21a790fa255dcae7db8998bf8655a95e76c033a6574190b398e4 +- pypi: https://files.pythonhosted.org/packages/a0/7e/e0a5d44bf070a1ff945050abc02ef1cff5ca9c6ab5dc6a16ab6322593a32/nvidia_cudnn_cu12-9.23.1.3-py3-none-manylinux_2_27_x86_64.whl + name: nvidia-cudnn-cu12 + version: 9.23.1.3 + sha256: 272d4815eef8f0dd21ecca768bfa18a618fb76ec31547dd0885cd49e76ebcd1d requires_dist: - - nvidia-cuda-cccl-cu12 + - nvidia-cublas-cu12 requires_python: '>=3' -- pypi: https://files.pythonhosted.org/packages/a0/d3/54cd560804a8c2b898824778e86c13c2a14600bc83532a9c4f69f2f469c3/array_api_compat-1.14.0-py3-none-any.whl - name: array-api-compat - version: 1.14.0 - sha256: ed5af1f9b6595a199c942505f281ec994892556b6efc24679a0501e87a7d6279 - requires_dist: - - cupy ; extra == 'cupy' - - dask>=2024.9.0 ; extra == 'dask' - - jax ; extra == 'jax' - - numpy>=1.22 ; extra == 'numpy' - - torch ; extra == 'pytorch' - - sparse>=0.15.1 ; extra == 'sparse' - - ndonnx ; extra == 'ndonnx' - - furo ; extra == 'docs' - - linkify-it-py ; extra == 'docs' - - myst-parser ; extra == 'docs' - - sphinx ; extra == 'docs' - - sphinx-copybutton ; extra == 'docs' - - sphinx-autobuild ; extra == 'docs' - - array-api-strict ; extra == 'dev' - - dask[array]>=2024.9.0 ; extra == 'dev' - - jax[cpu] ; extra == 'dev' - - ndonnx ; extra == 'dev' - - numpy>=1.22 ; extra == 'dev' - - pytest ; extra == 'dev' - - torch ; extra == 'dev' - - sparse>=0.15.1 ; extra == 'dev' - requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/ab/8a/18d4ff2c7bd83f30d6924bd4ad97abf418488c3f908dea228d6f0961ad68/ml_collections-1.1.0-py3-none-any.whl name: ml-collections version: 1.1.0 @@ -6675,6 +6652,11 @@ packages: - sqlalchemy>=1.4.0 ; extra == 'tiering-service' - uvloop ; extra == 'tiering-service' requires_python: '>=3.11' +- pypi: https://files.pythonhosted.org/packages/c9/59/7e6b812629d2f919e586041bffc130e1af32079f71bb20699eed54ed6d92/msgpack-1.2.0-cp313-cp313-macosx_11_0_arm64.whl + name: msgpack + version: 1.2.0 + sha256: 581e317112260d8ca488d490cad9290a5682276f309c41c7de237a85ed8799c8 + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/cb/c0/0a517bfe63ccd3b92eb254d264e28fca3c7cab75d07daea315250fb1bf73/nvidia_cublas_cu12-12.9.2.10-py3-none-manylinux_2_27_x86_64.whl name: nvidia-cublas-cu12 version: 12.9.2.10 @@ -6842,16 +6824,11 @@ packages: name: pyopengl version: 3.1.10 sha256: 794a943daced39300879e4e47bd94525280685f42dbb5a998d336cfff151d74f -- pypi: https://files.pythonhosted.org/packages/e1/67/921ec3024056483db83953ae8e48079ad62b92db7880013ca77632921dd0/readme_renderer-44.0-py3-none-any.whl - name: readme-renderer - version: '44.0' - sha256: 2fbca89b81a08526aadf1357a8c2ae889ec05fb03f5da67f9769c9a592166151 - requires_dist: - - nh3>=0.2.14 - - docutils>=0.21.2 - - pygments>=2.5.1 - - cmarkgfm>=0.8.0 ; extra == 'md' - requires_python: '>=3.9' +- pypi: https://files.pythonhosted.org/packages/e4/be/5b3cfe508bfab6761414ff944e3366eb13be4fd71efcd69450f89ba39f43/protobuf-7.35.1-cp310-abi3-manylinux2014_x86_64.whl + name: protobuf + version: 7.35.1 + sha256: 74758715c53d7158fb76caf4f0cfdacc5329a4b1bb994f865d6cf302d413a1c4 + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/e5/4c/93d0f85318da65923e4b91c1c2ff03d8a458cbefebe3bc612a6693c7906d/fire-0.7.1-py3-none-any.whl name: fire version: 0.7.1 @@ -7072,6 +7049,11 @@ packages: name: jax-cuda12-pjrt version: 0.10.1 sha256: 4c50a469f1b7c2fbba278d5b6932fe33de41f833b333cae28151422ec456857d +- pypi: https://files.pythonhosted.org/packages/fb/63/68f5d0ea81e167db5f59ddb94dc6f837667062113feff1c73fabf8907061/msgpack-1.2.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl + name: msgpack + version: 1.2.0 + sha256: a186027e4279efa4c8bf06ce30605498d7d0d3af0fba0b9799dce85a3fd4a93c + requires_python: '>=3.10' - pypi: https://files.pythonhosted.org/packages/ff/9a/9afaade874b2fa6c752c36f1548f718b5b83af81ed9b76628329dab81c1b/rfc3986-2.0.0-py2.py3-none-any.whl name: rfc3986 version: 2.0.0 diff --git a/properdocs.yml b/properdocs.yml index 35f86c7..6211cfd 100644 --- a/properdocs.yml +++ b/properdocs.yml @@ -16,7 +16,6 @@ theme: - content.code.annotate - navigation.footer - navigation.indexes - - navigation.sections - navigation.tabs - navigation.top - navigation.tracking @@ -49,8 +48,20 @@ nav: - Simulator Overview: user-guide/sim-overview.md - Object-Oriented API: user-guide/oo-api.md - Functional API: user-guide/functional-api.md - - Physics Models: user-guide/physics-models.md - - Control Modes: user-guide/control-modes.md + - Dynamics: + - user-guide/dynamics/index.md + - Dynamics functions: user-guide/dynamics/dynamics-functions.md + - Parametrization: user-guide/dynamics/parametrize.md + - Batching & domain randomization: user-guide/dynamics/batching.md + - Symbolic dynamics: user-guide/dynamics/symbolic.md + - System identification: user-guide/dynamics/system-identification.md + - Control Modes: + - user-guide/control/index.md + - Mellinger controller: user-guide/control/mellinger.md + - Parametrization: user-guide/control/parametrize.md + - Integral errors: user-guide/control/integral-errors.md + - Batching: user-guide/control/batching.md + - JIT compilation: user-guide/control/jit.md - Pipelines: user-guide/pipelines.md - Visualization: user-guide/visualization.md - MuJoCo Integration: user-guide/mujoco.md diff --git a/pyproject.toml b/pyproject.toml index 8a54bdf..e166cf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,14 +30,14 @@ dependencies = [ "flax", "ml_collections", "casadi", - "drone-models==0.1.0", # from PyPI (regular install) or submodules folder (pixi install) - "drone-controllers==0.2.0", # from PyPI (regular install) or submodules folder (pixi install) + "array-api-compat", + "array-api-extra", ] requires-python = ">=3.11,<3.14" # MuJoCo has no wheels for python 14 [project.optional-dependencies] gpu = ["jax[cuda12]"] -benchmark = ["fire", "matplotlib", "pandas"] +benchmark = ["fire", "matplotlib", "pandas", "pyinstrument"] [project.urls] Homepage = "https://github.com/learnsyslab/crazyflow" @@ -51,6 +51,11 @@ version = { attr = "crazyflow.__version__" } [tool.setuptools.package-data] crazyflow = ["scene.xml"] +"crazyflow.drones" = ["*.xml", "params.toml", "assets/*/*.stl"] +"crazyflow.dynamics.first_principles" = ["params.toml"] +"crazyflow.dynamics.so_rpy" = ["params.toml"] +"crazyflow.dynamics.so_rpy_rotor" = ["params.toml"] +"crazyflow.dynamics.so_rpy_rotor_drag" = ["params.toml"] [tool.pytest.ini_options] markers = ["unit", "integration", "render"] @@ -124,8 +129,6 @@ ruff = "*" [tool.pixi.pypi-dependencies] crazyflow = { path = "./", editable = true } -drone-models = { path = "./submodules/drone-models", editable = true } -drone-controllers = { path = "./submodules/drone-controllers", editable = true } [tool.pixi.feature.gpu] platforms = ["linux-64"] diff --git a/submodules/drone-controllers b/submodules/drone-controllers deleted file mode 160000 index a5f39c9..0000000 --- a/submodules/drone-controllers +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a5f39c9a6b3efb75420f146cadf270d415cc1ae2 diff --git a/submodules/drone-models b/submodules/drone-models deleted file mode 160000 index 51076d8..0000000 --- a/submodules/drone-models +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 51076d84485e4935013cbfcb475058f8f3ec461b diff --git a/tests/integration/dynamics/__init__.py b/tests/integration/dynamics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/dynamics/test_identification_pipeline.py b/tests/integration/dynamics/test_identification_pipeline.py new file mode 100644 index 0000000..03cdd4e --- /dev/null +++ b/tests/integration/dynamics/test_identification_pipeline.py @@ -0,0 +1,57 @@ +"""Tests for identification pipeline.""" + +from __future__ import annotations + +import array_api_compat.numpy as np +import pytest + +from crazyflow.dynamics.utils.data_utils import derivatives_svf, preprocessing +from crazyflow.dynamics.utils.identification import sys_id_rotation, sys_id_translation + + +@pytest.mark.integration +def test_sys_id_rotation(): + time = np.linspace(0.0, 1.0, 100) + pos = np.stack((np.cos(2 * np.pi * time), np.sin(2 * np.pi * time), time), axis=-1) + quat = np.zeros((100, 4)) + quat[:, 3] = 1.0 # No rotation + cmd_rpy = np.zeros((100, 3)) + cmd_f = np.ones(100) * 0.03 + data = { + "time": np.linspace(0.0, 1.0, 100), + "pos": pos, + "quat": quat, + "cmd_rpy": cmd_rpy, + "cmd_f": cmd_f, + } + + data = preprocessing(data) + data["rpy"] = np.roll(cmd_rpy, shift=10, axis=0) + data = derivatives_svf(data) + + sys_id_rotation(data) + + +@pytest.mark.integration +@pytest.mark.parametrize("dynamics", ["so_rpy", "so_rpy_rotor", "so_rpy_rotor_drag"]) +def test_sys_id_translation(dynamics: str): + mass = 0.03 + time = np.linspace(0.0, 1.0, 100) + phi = 2 * np.pi * time + pos = np.stack((np.cos(phi) * 0.1, np.sin(phi) * 0.1, np.roll(np.cos(phi), 10)), axis=-1) + quat = np.zeros((100, 4)) + quat[:, 3] = 1.0 # No rotation + cmd_rpy = np.zeros((100, 3)) + cmd_f = (np.ones(100) + np.cos(phi)) * mass * 9.81 + data = { + "time": np.linspace(0.0, 1.0, 100), + "pos": pos, + "quat": quat, + "cmd_rpy": cmd_rpy, + "cmd_f": cmd_f, + } + + data = preprocessing(data) + data = derivatives_svf(data) + + sys_id_translation(dynamics=dynamics, mass=mass, data=data) diff --git a/tests/integration/test_disturbance.py b/tests/integration/test_disturbance.py index dc2d691..359c958 100644 --- a/tests/integration/test_disturbance.py +++ b/tests/integration/test_disturbance.py @@ -2,9 +2,9 @@ import numpy as np import pytest -from crazyflow.sim import Physics, Sim +from crazyflow.sim import Dynamics, Sim from crazyflow.sim.data import SimData -from crazyflow.sim.pipeline import insert_fn_after +from crazyflow.sim.pipeline import append_fn def disturbance_fn(data: SimData) -> SimData: @@ -15,10 +15,10 @@ def disturbance_fn(data: SimData) -> SimData: return data.replace(states=states, core=data.core.replace(rng_key=key)) -@pytest.mark.parametrize("physics", Physics) +@pytest.mark.parametrize("dynamics", Dynamics) @pytest.mark.integration -def test_disturbance(physics: Physics): - sim = Sim(n_worlds=2, n_drones=3, control="state", physics=physics) +def test_disturbance(dynamics: Dynamics): + sim = Sim(n_worlds=2, n_drones=3, control="state", dynamics=dynamics) control = np.zeros((sim.n_worlds, sim.n_drones, 13)) control[..., :3] = 1.0 n_steps = 10 @@ -30,7 +30,7 @@ def test_disturbance(physics: Physics): pos.append(sim.data.states.pos[0, 0]) sim.reset() - insert_fn_after(sim.step_pipeline, "step_state_controller", disturbance_fn) + append_fn(sim.step_pipeline, disturbance_fn, name="disturbance") sim.build_step_fn() for _ in range(n_steps): sim.state_control(control) diff --git a/tests/integration/test_interfaces.py b/tests/integration/test_interfaces.py index fdaba7e..24109e4 100644 --- a/tests/integration/test_interfaces.py +++ b/tests/integration/test_interfaces.py @@ -1,20 +1,19 @@ import jax import numpy as np import pytest -from drone_controllers import parametrize -from drone_controllers.mellinger import state2attitude -from drone_models.core import load_params -from drone_models.transform import motor_force2rotor_vel from scipy.spatial.transform import Rotation as R -from crazyflow.control.control import Control -from crazyflow.sim import Physics, Sim +from crazyflow.control import Control, parametrize +from crazyflow.control.core import load_params +from crazyflow.control.mellinger import force_torque2rotor_vel, state2attitude +from crazyflow.control.transform import motor_force2rotor_vel +from crazyflow.sim import Dynamics, Sim @pytest.mark.integration -@pytest.mark.parametrize("physics", Physics) -def test_state_interface(physics: Physics): - sim = Sim(physics=physics, control=Control.state) +@pytest.mark.parametrize("dynamics", Dynamics) +def test_state_interface(dynamics: Dynamics): + sim = Sim(dynamics=dynamics, control=Control.state) # Simple P controller for attitude to reach target height cmd = np.zeros((1, 1, 13), dtype=np.float32) @@ -28,37 +27,38 @@ def test_state_interface(physics: Physics): # Check if drone reached target position distance = np.linalg.norm(sim.data.states.pos[0, 0] - np.array([0.0, 0.0, 1.0])) - assert distance < 0.1, f"Failed to reach target height with {physics} physics" + assert distance < 0.1, f"Failed to reach target height with {dynamics} dynamics" @pytest.mark.integration -@pytest.mark.parametrize("physics", Physics) -def test_attitude_interface(physics: Physics): - sim = Sim(physics=physics, control=Control.attitude) +@pytest.mark.parametrize("dynamics", Dynamics) +def test_attitude_interface(dynamics: Dynamics): + sim = Sim(dynamics=dynamics, control=Control.attitude) target_pos = np.array([0.0, 0.0, 1.0]) - jit_state2attitude = jax.jit(parametrize(state2attitude, drone_model="cf2x_L250")) + jit_state2attitude = jax.jit(parametrize(state2attitude, drone="cf2x_L250")) - i_error = np.zeros((1, 1, 3)) + pos_err_i = np.zeros((1, 1, 3)) cmd = np.zeros((1, 1, 13)) cmd[0, 0, 2] = 1.0 # Set z position target to 1.0 for _ in range(int(2 * sim.control_freq)): # Run simulation for 2 seconds pos, vel, quat = sim.data.states.pos, sim.data.states.vel, sim.data.states.quat - rpyt, i_error = jit_state2attitude(pos, quat, vel, cmd, (i_error,), ctrl_freq=100) + rpyt, pos_err_i = jit_state2attitude(pos, quat, vel, cmd, pos_err_i, ctrl_freq=100) sim.attitude_control(rpyt) sim.step(sim.freq // sim.control_freq) # Check if drone maintained hover position dpos = sim.data.states.pos[0, 0] - target_pos distance = np.linalg.norm(dpos) - assert distance < 0.05, f"Failed to maintain hover with {physics} ({dpos})" + assert distance < 0.05, f"Failed to maintain hover with {dynamics} ({dpos})" @pytest.mark.integration def test_rotor_vel_interface(): - sim = Sim(physics=Physics.first_principles, control=Control.rotor_vel) - params = load_params("first_principles", "cf2x_L250") - max_rpm = motor_force2rotor_vel(np.array([params["thrust_max"]]), params["rpm2thrust"])[0] + sim = Sim(dynamics=Dynamics.first_principles, control=Control.rotor_vel) + thrust_max = load_params(state2attitude, sim.drone)["thrust_max"] + rpm2thrust = load_params(force_torque2rotor_vel, sim.drone)["rpm2thrust"] + max_rpm = motor_force2rotor_vel(np.array([thrust_max]), rpm2thrust)[0] sim.data = sim.data.replace( states=sim.data.states.replace(pos=sim.data.states.pos.at[..., 2].set(0.5)) @@ -73,10 +73,10 @@ def test_rotor_vel_interface(): @pytest.mark.integration -@pytest.mark.parametrize("physics", Physics) -def test_swarm_control(physics: Physics): +@pytest.mark.parametrize("dynamics", Dynamics) +def test_swarm_control(dynamics: Dynamics): n_worlds, n_drones = 2, 3 - sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=physics, control=Control.state) + sim = Sim(n_worlds=n_worlds, n_drones=n_drones, dynamics=dynamics, control=Control.state) target_pos = sim.data.states.pos + np.array([0.3, 0.3, 0.3]) cmd = np.zeros((n_worlds, n_drones, 13)) @@ -89,13 +89,12 @@ def test_swarm_control(physics: Physics): @pytest.mark.integration -@pytest.mark.parametrize("physics", Physics) -def test_yaw_rotation(physics: Physics): - # TODO: Enable yaw rotations once the models are better calibrated - if physics != Physics.first_principles: - pytest.skip(f"Physics mode {physics} currently does not support yaw rotation") +@pytest.mark.parametrize("dynamics", Dynamics) +def test_yaw_rotation(dynamics: Dynamics): + if dynamics != Dynamics.first_principles: + pytest.skip(f"Dynamics mode {dynamics} currently does not support yaw rotation") - sim = Sim(physics=physics, control=Control.state, state_freq=100) + sim = Sim(dynamics=dynamics, control=Control.state, state_freq=100) sim.reset() cmd = np.zeros((sim.n_worlds, sim.n_drones, 13)) diff --git a/tests/integration/test_models.py b/tests/integration/test_models.py index 4d4b26f..53d8896 100644 --- a/tests/integration/test_models.py +++ b/tests/integration/test_models.py @@ -1,13 +1,13 @@ import pytest -from drone_models.drones import available_drones +from crazyflow import available_drones +from crazyflow.dynamics import Dynamics from crazyflow.sim import Sim -from crazyflow.sim.physics import Physics @pytest.mark.integration -@pytest.mark.parametrize("physics", Physics) -@pytest.mark.parametrize("model", available_drones) -def test_attitude_symbolic(physics: Physics, model: "str"): +@pytest.mark.parametrize("dynamics", Dynamics) +@pytest.mark.parametrize("drone", available_drones) +def test_attitude_symbolic(dynamics: Dynamics, drone: "str"): """Tests if xml files contain syntax errors.""" - Sim(physics=physics, drone_model=model) + Sim(dynamics=dynamics, drone=drone) diff --git a/tests/integration/test_reset.py b/tests/integration/test_reset.py index 67951e8..2118a32 100644 --- a/tests/integration/test_reset.py +++ b/tests/integration/test_reset.py @@ -4,14 +4,14 @@ import crazyflow # noqa: F401, register gymnasium envs from crazyflow.control import Control -from crazyflow.sim import Physics, Sim +from crazyflow.sim import Dynamics, Sim @pytest.mark.integration -@pytest.mark.parametrize("physics", Physics) -def test_reset_during_simulation(physics: Physics): +@pytest.mark.parametrize("dynamics", Dynamics) +def test_reset_during_simulation(dynamics: Dynamics): """Test reset behavior during an active simulation.""" - sim = Sim(physics=physics, control=Control.attitude) + sim = Sim(dynamics=dynamics, control=Control.attitude) # Run simulation n_steps = 3 random_cmds = np.random.rand(n_steps, 1, 1, 4) @@ -36,11 +36,11 @@ def test_reset_during_simulation(physics: Physics): @pytest.mark.integration -@pytest.mark.parametrize("physics", Physics) -def test_reset_multi_world(physics: Physics): +@pytest.mark.parametrize("dynamics", Dynamics) +def test_reset_multi_world(dynamics: Dynamics): """Test reset behavior with multiple worlds.""" n_worlds, n_drones = 2, 2 - sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=physics, control=Control.attitude) + sim = Sim(n_worlds=n_worlds, n_drones=n_drones, dynamics=dynamics, control=Control.attitude) n_steps = 3 random_cmds = np.random.rand(n_steps, n_worlds, n_drones, 4) diff --git a/tests/integration/test_symbolic.py b/tests/integration/test_symbolic.py deleted file mode 100644 index cac9101..0000000 --- a/tests/integration/test_symbolic.py +++ /dev/null @@ -1,71 +0,0 @@ -import casadi as cs -import numpy as np -import pytest -from numpy.typing import NDArray - -from crazyflow.sim import Sim -from crazyflow.sim.data import SimState -from crazyflow.sim.physics import Physics -from crazyflow.sim.pipeline import remove_fn -from crazyflow.sim.symbolic import symbolic_from_sim - - -def sim_state2symbolic_state(state: SimState) -> NDArray[np.float32]: - """Convert the simulation state to the symbolic state vector.""" - return np.concat([state.pos, state.quat, state.vel, state.ang_vel], axis=-1)[0, 0][..., None] - - -@pytest.mark.integration -@pytest.mark.parametrize("physics", Physics) -@pytest.mark.parametrize("freq", [500, 1000]) -def test_attitude_symbolic(physics: Physics, freq: int): - if physics in (Physics.so_rpy_rotor, Physics.so_rpy_rotor_drag): - pytest.skip(f"Physics mode {physics} not yet implemented") - - sim = Sim(physics=physics, freq=freq) - remove_fn(sim.step_pipeline, "clip_floor_pos") # Remove clip floor from step pipeline - X_dot, X, U, Y = symbolic_from_sim(sim) - fd = cs.integrator("fd", "cvodes", {"x": X, "p": U, "ode": X_dot}, 0, 1 / freq) - - x0 = sim_state2symbolic_state(sim.data.states) - - # Simulate with both models for 0.5 seconds - t_end = 0.5 - dt = 1 / sim.freq - steps = int(t_end / dt) - - # Track states over time - x_sym_log = [] - x_sim_log = [] - - # Initialize logs with initial state - x_sym = x0.copy() - x_sym_log.append(x_sym) - x_sim = x0.copy() - x_sim_log.append(x_sim) - - u_low = np.array([-np.pi, -np.pi, -np.pi, 0.3]).reshape(4, 1) - u_high = np.array([np.pi, np.pi, np.pi, 0.5]).reshape(4, 1) - rng = np.random.default_rng(seed=42) - - # Run simulation - for _ in range(steps): - u_rand = (rng.random(4)[..., None] * (u_high - u_low) + u_low).astype(np.float32) - assert x_sym.shape == (13, 1) - assert u_rand.shape == (4, 1) - # Simulate with symbolic model - x_sym = fd(x0=x_sym, p=u_rand)["xf"].full() - x_sym_log.append(x_sym) - # Simulate with attitude controller - sim.attitude_control(u_rand.reshape(1, 1, 4)) - sim.step(sim.freq // sim.control_freq) - x_sim_log.append(sim_state2symbolic_state(sim.data.states)) - - # Convert logs to arrays. Do not record the rpy rates (deviate easily). - x_sym_log = np.array(x_sym_log)[..., :-3] - x_sim_log = np.array(x_sim_log)[..., :-3] - - # Check if states match throughout simulation - err_msg = "Symbolic and simulation prediction do not match approximately" - assert np.allclose(x_sym_log, x_sim_log, rtol=1e-2, atol=1e-2), err_msg - sim.close() diff --git a/tests/unit/control/__init__.py b/tests/unit/control/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/control/test_core.py b/tests/unit/control/test_core.py new file mode 100644 index 0000000..c8e7743 --- /dev/null +++ b/tests/unit/control/test_core.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import inspect +from typing import Any, Callable + +import array_api_strict +import pytest + +from crazyflow.control.core import load_params, parametrize +from crazyflow.control.mellinger import ( + attitude2force_torque, + force_torque2rotor_vel, + state2attitude, +) +from crazyflow.drones import available_drones + +_MELLINGER_FNS = [state2attitude, attitude2force_torque, force_torque2rotor_vel] + + +@pytest.mark.unit +@pytest.mark.parametrize("fn", _MELLINGER_FNS, ids=lambda fn: fn.__name__) +@pytest.mark.parametrize("drone", available_drones) +def test_load_params_keys(fn: Callable[..., Any], drone: str) -> None: + params = load_params(fn, drone) + fn_params = inspect.signature(fn).parameters + fn_kwargs = {k for k, v in fn_params.items() if v.kind == inspect.Parameter.KEYWORD_ONLY} + assert fn_kwargs <= set(params.keys()), f"Missing keys: {fn_kwargs - set(params.keys())}" + + +@pytest.mark.unit +def test_load_params_unknown_drone() -> None: + with pytest.raises(KeyError, match="nonexistent_drone"): + load_params(state2attitude, "nonexistent_drone") + + +@pytest.mark.unit +def test_parametrize_unknown_drone() -> None: + with pytest.raises(KeyError): + parametrize(state2attitude, "nonexistent_drone") + + +@pytest.mark.unit +@pytest.mark.parametrize("drone", available_drones) +def test_parametrize_xp_namespace(drone: str) -> None: + controller = parametrize(state2attitude, drone, xp=array_api_strict) + xp_array_type = type(array_api_strict.asarray(0.0)) + assert all(isinstance(v, xp_array_type) for v in controller.keywords.values()) diff --git a/tests/unit/control/test_mellinger.py b/tests/unit/control/test_mellinger.py new file mode 100644 index 0000000..0876cc7 --- /dev/null +++ b/tests/unit/control/test_mellinger.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from crazyflow.control import parametrize +from crazyflow.control.core import load_params +from crazyflow.control.mellinger import ( + attitude2force_torque, + force_torque2rotor_vel, + state2attitude, +) +from crazyflow.drones import available_drones + +if TYPE_CHECKING: + from crazyflow._typing import Array # To be changed to array_api_typing later + + +def create_rnd_states(shape: tuple[int, ...] = ()) -> tuple[Array, Array, Array, Array]: + x = np.random.randn(*shape, 3 + 4 + 3 + 3) + return x[..., :3], x[..., 3:7], x[..., 7:10], x[..., 10:13] + + +@pytest.mark.unit +@pytest.mark.parametrize("drone", available_drones) +def test_state2attitude(drone: str) -> None: + controller = parametrize(state2attitude, drone) + # Single input + pos, quat, vel, ang_vel = create_rnd_states() + rpyt, pos_err_i = controller(pos, quat, vel, np.ones(13), ctrl_freq=100) + assert rpyt.shape == (4,) + assert pos_err_i.shape == (3,) + # Batch input + pos, quat, vel, ang_vel = create_rnd_states((5, 4)) + rpyt, pos_err_i = controller(pos, quat, vel, np.ones((5, 4, 13)), ctrl_freq=100) + assert rpyt.shape == (5, 4, 4) + assert pos_err_i.shape == (5, 4, 3) + + +@pytest.mark.unit +@pytest.mark.parametrize("drone", available_drones) +def test_attitude2force_torque(drone: str) -> None: + controller = parametrize(attitude2force_torque, drone) + # Single input + pos, quat, vel, ang_vel = create_rnd_states() + rpyt_cmd = np.array([0.1, 0.1, 0.1, 1.0]) # roll, pitch, yaw, thrust command + force_des, torque_des, r_int_error = controller(quat, ang_vel, rpyt_cmd) + assert force_des.shape == (1,) + assert torque_des.shape == (3,) + assert r_int_error.shape == (3,) + # Batch input + pos, quat, vel, ang_vel = create_rnd_states((5, 4)) + rpyt_cmd = np.random.randn(5, 4, 4) + rpyt_cmd[..., 3] = np.abs(rpyt_cmd[..., 3]) # Ensure positive thrust + force_des, torque_des, r_int_error = controller(quat, ang_vel, rpyt_cmd) + assert force_des.shape == (5, 4, 1) + assert torque_des.shape == (5, 4, 3) + assert r_int_error.shape == (5, 4, 3) + + +@pytest.mark.unit +@pytest.mark.parametrize("drone", available_drones) +def test_force_torque2rotor_vel(drone: str) -> None: + controller = parametrize(force_torque2rotor_vel, drone) + # Single input + force = np.array([1.0]) + torque = np.array([0.1, 0.1, 0.1]) + rotor_vel = controller(force, torque) + assert rotor_vel.shape == (4,) + # Batch input + force = np.ones((5, 4, 1)) + torque = np.random.randn(5, 4, 3) * 0.1 + rotor_vel = controller(force, torque) + assert rotor_vel.shape == (5, 4, 4) + + +# Correctness / physics + + +@pytest.mark.unit +@pytest.mark.parametrize("drone", available_drones) +def test_state2attitude_at_setpoint(drone: str) -> None: + # At setpoint with identity orientation and zero acc, RPY command should be + # [0, 0, 0] and thrust must be positive (hovering against gravity). + controller = parametrize(state2attitude, drone) + pos = np.zeros(3) + quat = np.array([0.0, 0.0, 0.0, 1.0]) + vel = np.zeros(3) + cmd = np.zeros(13) # setpoint at origin, zero vel/acc, yaw=0 + rpyt, _ = controller(pos, quat, vel, cmd) + assert np.allclose(rpyt[:3], 0.0, atol=1e-6), f"RPY at setpoint should be ~0, got {rpyt[:3]}" + assert rpyt[3] > 0.0, "Hovering thrust must be positive" + + +@pytest.mark.unit +@pytest.mark.parametrize("drone", available_drones) +def test_state2attitude_integral_error_accumulation(drone: str) -> None: + # A constant position error must cause the integral error to accumulate + # linearly until it would exceed int_err_max (clipped by the controller). + controller = parametrize(state2attitude, drone) + params = load_params(state2attitude, drone) + pos = np.zeros(3) + quat = np.array([0.0, 0.0, 0.0, 1.0]) + vel = np.zeros(3) + cmd = np.zeros(13) + cmd[0] = 1.0 # 1 m setpoint error in x + ctrl_freq = 100.0 + dt = 1.0 / ctrl_freq + steps = 5 + + pos_err_i = None + for _ in range(steps): + _, pos_err_i = controller(pos, quat, vel, cmd, pos_err_i=pos_err_i, ctrl_freq=ctrl_freq) + + expected = np.clip( + np.array([steps * dt, 0.0, 0.0]), -params["int_err_max"], params["int_err_max"] + ) + assert np.allclose(pos_err_i, expected, atol=1e-6) + + +@pytest.mark.unit +@pytest.mark.parametrize("drone", available_drones) +def test_attitude2force_torque_at_setpoint(drone: str) -> None: + # Identity orientation commanded → zero attitude error → zero corrective torque. + controller = parametrize(attitude2force_torque, drone) + quat = np.array([0.0, 0.0, 0.0, 1.0]) + ang_vel = np.zeros(3) + cmd = np.array([0.0, 0.0, 0.0, 0.5]) # RPY=0, positive thrust + force_des, torque_des, _ = controller(quat, ang_vel, cmd) + assert np.allclose(torque_des, 0.0, atol=1e-6), ( + f"Torque at setpoint should be ~0, got {torque_des}" + ) + assert force_des[0] > 0.0, "Force must be positive for positive thrust command" + + +@pytest.mark.unit +@pytest.mark.parametrize("drone", available_drones) +def test_attitude2force_torque_zero_thrust(drone: str): + # Zero thrust command → firmware zeros torque; outputs are all zero. + controller = parametrize(attitude2force_torque, drone) + quat = np.array([0.0, 0.0, 0.0, 1.0]) + ang_vel = np.zeros(3) + cmd = np.array([0.1, 0.1, 0.1, 0.0]) # non-zero RPY but zero thrust + force_des, torque_des, _ = controller(quat, ang_vel, cmd) + assert np.allclose(force_des, 0.0, atol=1e-6) + assert np.allclose(torque_des, 0.0, atol=1e-6) + + +# Batch consistency (batch result == sequential result) + + +@pytest.mark.unit +@pytest.mark.parametrize("drone", available_drones) +def test_state2attitude_batch_consistency(drone: str): + controller = parametrize(state2attitude, drone) + batch = (3, 2) + pos, quat, vel, _ = create_rnd_states(batch) + cmd = np.random.randn(*batch, 13) + rpyt_batch, err_batch = controller(pos, quat, vel, cmd) + for i in range(batch[0]): + for j in range(batch[1]): + rpyt_s, err_s = controller(pos[i, j], quat[i, j], vel[i, j], cmd[i, j]) + assert np.allclose(rpyt_batch[i, j], rpyt_s, atol=1e-5) + assert np.allclose(err_batch[i, j], err_s, atol=1e-5) + + +@pytest.mark.unit +@pytest.mark.parametrize("drone", available_drones) +def test_attitude2force_torque_batch_consistency(drone: str): + controller = parametrize(attitude2force_torque, drone) + batch = (3, 2) + _, quat, _, ang_vel = create_rnd_states(batch) + cmd = np.random.randn(*batch, 4) + cmd[..., 3] = np.abs(cmd[..., 3]) + force_batch, torque_batch, err_batch = controller(quat, ang_vel, cmd) + for i in range(batch[0]): + for j in range(batch[1]): + force_s, torque_s, err_s = controller(quat[i, j], ang_vel[i, j], cmd[i, j]) + assert np.allclose(force_batch[i, j], force_s, atol=1e-5) + assert np.allclose(torque_batch[i, j], torque_s, atol=1e-5) + assert np.allclose(err_batch[i, j], err_s, atol=1e-5) + + +@pytest.mark.unit +@pytest.mark.parametrize("drone", available_drones) +def test_force_torque2rotor_vel_batch_consistency(drone: str): + controller = parametrize(force_torque2rotor_vel, drone) + batch = (3, 2) + force = np.abs(np.random.randn(*batch, 1)) * 0.05 + 0.05 + torque = np.random.randn(*batch, 3) * 0.001 + rpm_batch = controller(force, torque) + for i in range(batch[0]): + for j in range(batch[1]): + rpm_s = controller(force[i, j], torque[i, j]) + assert np.allclose(rpm_batch[i, j], rpm_s, atol=1e-5) + + +# Symmetric force check + + +@pytest.mark.unit +@pytest.mark.parametrize("drone", available_drones) +def test_force_torque2rotor_vel_symmetric(drone: str): + # Pure vertical force with zero torque → X-frame symmetry → all 4 RPMs equal. + controller = parametrize(force_torque2rotor_vel, drone) + force = np.array([0.2]) # total thrust, split equally across 4 motors + torque = np.zeros(3) + rotor_vel = controller(force, torque) + assert np.allclose(rotor_vel, rotor_vel[0], rtol=1e-5), f"RPMs not equal: {rotor_vel}" diff --git a/tests/unit/control/test_transform.py b/tests/unit/control/test_transform.py new file mode 100644 index 0000000..4925a8a --- /dev/null +++ b/tests/unit/control/test_transform.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np +import pytest + +from crazyflow.control.transform import force2pwm, motor_force2rotor_vel, pwm2force +from crazyflow.drones import load_params + + +@pytest.fixture(scope="module") +def core_params() -> dict[str, Any]: + return {k: np.asarray(v) for k, v in load_params("cf2x_L250").items()} + + +@pytest.mark.unit +def test_force2pwm_pwm2force_roundtrip(core_params: dict[str, Any]) -> None: + thrust_max = float(core_params["thrust_max"]) + pwm_max = float(core_params["pwm_max"]) + forces = np.array([0.0, thrust_max * 0.25, thrust_max * 0.5, thrust_max]) + assert np.allclose( + pwm2force(force2pwm(forces, thrust_max, pwm_max), thrust_max, pwm_max), forces + ) + + +@pytest.mark.unit +def test_force2pwm_boundary(core_params: dict[str, Any]) -> None: + thrust_max = float(core_params["thrust_max"]) + pwm_max = float(core_params["pwm_max"]) + assert force2pwm(0.0, thrust_max, pwm_max) == pytest.approx(0.0) + assert force2pwm(thrust_max, thrust_max, pwm_max) == pytest.approx(pwm_max) + + +@pytest.mark.unit +def test_motor_force2rotor_vel_shape(core_params: dict[str, Any]) -> None: + rpm2thrust = core_params["rpm2thrust"] + assert motor_force2rotor_vel(np.full(4, 0.05), rpm2thrust).shape == (4,) + assert motor_force2rotor_vel(np.full((3, 2, 4), 0.05), rpm2thrust).shape == (3, 2, 4) + + +@pytest.mark.unit +def test_motor_force2rotor_vel_positive(core_params: dict[str, Any]) -> None: + rpm2thrust = core_params["rpm2thrust"] + forces = np.linspace(0.02, 0.12, 10) + assert np.all(motor_force2rotor_vel(forces, rpm2thrust) > 0) diff --git a/tests/unit/dynamics/__init__.py b/tests/unit/dynamics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/dynamics/test_dynamics.py b/tests/unit/dynamics/test_dynamics.py new file mode 100644 index 0000000..ae128ea --- /dev/null +++ b/tests/unit/dynamics/test_dynamics.py @@ -0,0 +1,242 @@ +"""Tests of the numeric dynamics.""" + +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, Callable + +import array_api_strict as xp +import casadi as cs +import jax +import jax.numpy as jp +import numpy as np +import pytest +from array_api_compat import device as xp_device + +from crazyflow.drones import available_drones +from crazyflow.dynamics import available_dynamics, dynamics_features +from crazyflow.dynamics.core import parametrize + +if TYPE_CHECKING: + from crazyflow._typing import Array # To be changed to array_api_typing later + + +@pytest.fixture(autouse=True) +def _enable_x64(): + """Run only this module in float64 so jax matches numpy precision, then restore the default. + + jax_enable_x64 is a global flag, so we scope it to this file's tests to keep every other test + running in float32. + """ + prev = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", True) + try: + yield + finally: + jax.config.update("jax_enable_x64", prev) + + +def create_rnd_states( + shape: tuple[int, ...] = (), +) -> tuple[Array, Array, Array, Array, Array, Array, Array]: + x = np.random.randn(*shape, 3 + 4 + 3 + 3 + 4 + 3 + 3) + pos = xp.asarray(x[..., :3]) + quat = xp.asarray(x[..., 3:7]) + vel = xp.asarray(x[..., 7:10]) + ang_vel = xp.asarray(x[..., 10:13]) + rotor_vel = xp.abs(xp.asarray(x[..., 13:17])) # Rotor velocities must be positive + dist_f = xp.asarray(x[..., 17:20]) + dist_t = xp.asarray(x[..., 20:23]) + return pos, quat, vel, ang_vel, rotor_vel, dist_f, dist_t + + +def create_rnd_commands(shape: tuple[int, ...] = (), dim: int = 4) -> Array: + """Creates N random inputs with size dim.""" + return xp.abs(xp.asarray(np.random.randn(*shape, dim))) # Motor forces must be positive + + +def make_inputs( + dynamics: Callable, + *, + batch: tuple[int, ...] = (), + rotor_vel: bool = True, + ext_wrench: bool = False, +) -> dict[str, Array]: + """Build random inputs for a parametrized dynamics function, to splat as ``dynamics(**inp)``. + + Only the requested inputs are included, so the dict doubles as the call signature: dynamics + without a rotor_vel parameter (so_rpy) simply never get one, and omitted optional inputs fall + back to the function defaults. + + Args: + dynamics: The (parametrized) dynamics function. + batch: Batch shape of the inputs. + rotor_vel: Whether to provide rotor_vel (ignored for dynamics without rotor dynamics). + Set to False to exercise the commanded-rotor-velocity fallback. + ext_wrench: Whether to provide external force/torque disturbances. + """ + pos, quat, vel, ang_vel, rv, dist_f, dist_t = create_rnd_states(batch) + cmd = create_rnd_commands(batch) + inp = {"pos": pos, "quat": quat, "vel": vel, "ang_vel": ang_vel, "cmd": cmd} + if rotor_vel and dynamics_features(dynamics)["rotor_dynamics"]: + inp["rotor_vel"] = rv + if ext_wrench: + inp["dist_f"], inp["dist_t"] = dist_f, dist_t + return inp + + +def state_vector(inp: dict) -> Array: + """Stacked state vector matching the symbolic state X (present inputs in canonical order).""" + order = ("pos", "quat", "vel", "ang_vel", "rotor_vel", "dist_f", "dist_t") + return xp.concat([inp[k] for k in order if k in inp], axis=-1) + + +def symbolic_flags(dynamics: Callable, dist: bool = False) -> dict[str, bool]: + """Build the symbolic_dynamics flags, gating model_rotor_vel on the rotor dynamics feature.""" + flags = {} + if dynamics_features(dynamics)["rotor_dynamics"]: + flags["model_rotor_vel"] = True + if dist: + flags["model_dist_f"] = flags["model_dist_t"] = True + return flags + + +def assert_array_meta(x: Array | None, y: Array | None, name: str | None = None): + """Assert the output is on the correct device, has the correct type and shape.""" + if x is None and y is None: + return + prefix = "" if name is None else f"{name}: " + assert isinstance(x, type(y)), ( + f"{prefix}Output type {type(x)} does not match expected {type(y)}" + ) + assert xp_device(x) == xp_device(y), ( + f"{prefix}Output device {xp_device(x)} does not match expected {xp_device(y)}" + ) + assert x.shape == y.shape, f"{prefix}Output shape {x.shape} does not match expected {y.shape}" + assert np.all(np.isnan(x) == np.isnan(y)), f"{prefix}Derivative of non-nan values are NaN" + + +def assert_shapes(dynamics: Callable, inp: dict): + """Assert the dynamics output has the correct type, device and shape for each derivative.""" + out = dynamics(**inp) + names = ["dpos", "dquat", "dvel", "dang_vel"] + expected = [inp["pos"], inp["quat"], inp["vel"], inp["ang_vel"]] + if dynamics_features(dynamics)["rotor_dynamics"]: + names.append("drotor_vel") + expected.append(inp.get("rotor_vel")) + for name, dx, x in zip(names, out, expected, strict=True): + assert_array_meta(dx, x, name=name) + + +def check_shapes(dynamics: Callable, batch: tuple[int, ...] = ()): + """Check output shapes with/without external wrench, and the rotor_vel fallback warning.""" + assert_shapes(dynamics, make_inputs(dynamics, batch=batch)) + assert_shapes(dynamics, make_inputs(dynamics, batch=batch, ext_wrench=True)) + if not dynamics_features(dynamics)["rotor_dynamics"]: + return + for ext_wrench in (False, True): + inp = make_inputs(dynamics, batch=batch, rotor_vel=False, ext_wrench=ext_wrench) + with pytest.warns(UserWarning, match="Rotor velocity not provided"): + assert_shapes(dynamics, inp) + + +@pytest.mark.unit +@pytest.mark.parametrize("dynamics_name, dynamics", available_dynamics.items()) +def test_dynamics_features(dynamics_name: str, dynamics: Callable): + """Tests if the dynamics features are correctly set.""" + assert hasattr(dynamics, "__dynamics_features__"), ( + f"Dynamics function {dynamics_name} does not have __dynamics_features__ attribute" + ) + features = dynamics_features(dynamics) + assert isinstance(features, dict), ( + f"dynamics features should be a dict, got {type(features)} for {dynamics_name}" + ) + assert "rotor_dynamics" in features, ( + f"dynamics features should contain 'rotor_dynamics' key for {dynamics_name}" + ) + + +@pytest.mark.unit +@pytest.mark.parametrize("dynamics_name, dynamics", available_dynamics.items()) +@pytest.mark.parametrize("drone", available_drones) +def test_dynamics_shapes(dynamics_name: str, dynamics: Callable, drone: str): + check_shapes(parametrize(dynamics, drone)) + + +@pytest.mark.unit +@pytest.mark.parametrize("dynamics_name, dynamics", available_dynamics.items()) +@pytest.mark.parametrize("drone", available_drones) +def test_dynamics_shapes_batched(dynamics_name: str, dynamics: Callable, drone: str): + dynamics = parametrize(dynamics, drone, xp=xp) + shape = (10, 5) + check_shapes(dynamics, batch=shape) + # Batched parameters + dynamics.keywords["J"] = xp.tile(dynamics.keywords["J"][None, None, ...], shape + (1, 1)) + dynamics.keywords["J_inv"] = xp.tile( + dynamics.keywords["J_inv"][None, None, ...], shape + (1, 1) + ) + check_shapes(dynamics, batch=shape) + + +@pytest.mark.unit +@pytest.mark.parametrize("dynamics_name, dynamics", available_dynamics.items()) +@pytest.mark.parametrize("drone", available_drones) +@pytest.mark.parametrize("ext_wrench", [False, True]) +def test_symbolic_dynamics(dynamics_name: str, dynamics: Callable, drone: str, ext_wrench: bool): + """Tests if the symbolic and numeric dynamics produce the same output.""" + symbolic_dynamics = getattr(sys.modules[dynamics.__module__], "symbolic_dynamics") + symbolic_dynamics = parametrize(symbolic_dynamics, drone) + dynamics = parametrize(dynamics, drone) + inp = make_inputs(dynamics, batch=(10, 5), ext_wrench=ext_wrench) + + X_dot, X, U, _ = symbolic_dynamics(**symbolic_flags(dynamics, dist=ext_wrench)) + symbolic2numeric = cs.Function(dynamics_name, [X, U], [X_dot]) + + for i in np.ndindex(np.shape(inp["pos"])[:-1]): # casadi only supports non batched calls + inp_i = {k: v[i + (...,)] for k, v in inp.items()} + x_dot = xp.concat([x for x in dynamics(**inp_i) if x is not None], axis=-1) + X, U = np.asarray(state_vector(inp_i)), np.asarray(inp_i["cmd"]) + x_dot_symbolic2numeric = xp.squeeze(xp.asarray(symbolic2numeric(X, U)), axis=-1) + assert np.allclose(x_dot, x_dot_symbolic2numeric), ( + "Symbolic and numeric dynamics have different output" + ) + + +@pytest.mark.unit +@pytest.mark.parametrize("dynamics_name, dynamics", available_dynamics.items()) +@pytest.mark.parametrize("drone", available_drones) +def test_compare_batched_non_batched(dynamics_name: str, dynamics: Callable, drone: str): + """Tests if batching works and if the results are identical to the non-batched version.""" + dynamics = parametrize(dynamics, drone) + inp = make_inputs(dynamics, batch=(10, 5)) + + x_dot_batched = xp.concat([x for x in dynamics(**inp) if x is not None], axis=-1) + for i in np.ndindex(np.shape(inp["pos"])[:-1]): + out = dynamics(**{k: v[i + (...,)] for k, v in inp.items()}) + x_dot = xp.concat([x for x in out if x is not None], axis=-1) + assert np.allclose(x_dot_batched[i + (...,)], x_dot, atol=1e-5), ( + "Non-batched and batched results are not the same" + ) + + +@pytest.mark.unit +@pytest.mark.parametrize("dynamics_name, dynamics", available_dynamics.items()) +@pytest.mark.parametrize("drone", available_drones) +def test_numeric_jit(dynamics_name: str, dynamics: Callable, drone: str): + """Tests if the dynamics are jitable and if the results are identical to the array API ones.""" + dynamics = parametrize(dynamics, drone) + inp = make_inputs(dynamics, batch=(10, 5)) + xp_dot = dynamics(**inp) + jp_dot = jax.jit(dynamics)(**{k: jp.asarray(np.asarray(v)) for k, v in inp.items()}) + + assert isinstance(jp_dot[0], jp.ndarray), "Results are not jax arrays" + xp_dot = xp.concat([x for x in xp_dot if x is not None], axis=-1) + jp_dot = jp.concat([x for x in jp_dot if x is not None], axis=-1) + assert np.allclose(xp_dot, jp_dot), "numpy and jax results differ" + + +# TODO test if external wrench gets applied properly. But how to test it? +# -> maybe apply and predict based on mass how much higher the acceleration should be +# same for torque +@pytest.mark.unit +def test_external_wrench(): ... diff --git a/tests/unit/dynamics/test_parametrization.py b/tests/unit/dynamics/test_parametrization.py new file mode 100644 index 0000000..1472a22 --- /dev/null +++ b/tests/unit/dynamics/test_parametrization.py @@ -0,0 +1,19 @@ +"""Tests of the parametrization of the dynamics.""" + +from __future__ import annotations + +from typing import Callable + +import pytest + +from crazyflow.drones import available_drones +from crazyflow.dynamics import available_dynamics +from crazyflow.dynamics.core import parametrize + + +@pytest.mark.unit +@pytest.mark.parametrize("dynamics_name, dynamics", available_dynamics.items()) +@pytest.mark.parametrize("drone", available_drones) +def test_dynamics_parametrization(dynamics_name: str, dynamics: Callable, drone: str): + """Check that we can parametrize all available dynamics with all drones.""" + parametrize(dynamics, drone) diff --git a/tests/unit/dynamics/test_preprocessing.py b/tests/unit/dynamics/test_preprocessing.py new file mode 100644 index 0000000..e83d8aa --- /dev/null +++ b/tests/unit/dynamics/test_preprocessing.py @@ -0,0 +1,60 @@ +"""Tests for identification utils.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from crazyflow.dynamics.utils.data_utils import derivatives_svf, preprocessing + + +@pytest.mark.unit +def test_preprocessing(): + """Test preprocessing function.""" + time = np.linspace(0.0, 1.0, 100) + pos = np.stack((np.cos(2 * np.pi * time), np.sin(2 * np.pi * time), time), axis=-1) + quat = np.zeros((100, 4)) + quat[:, 3] = 1.0 # No rotation + cmd_rpy = np.zeros((100, 3)) + cmd_f = np.ones(100) * 0.03 + data = { + "time": np.linspace(0.0, 1.0, 100), + "pos": pos, + "quat": quat, + "cmd_rpy": cmd_rpy, + "cmd_f": cmd_f, + } + + data_processed = preprocessing(data) + data_processed["rpy"] + + +@pytest.mark.unit +def test_derivatives_svf(): + """Test preprocessing function.""" + time = np.linspace(0.0, 1.0, 100) + pos = np.stack((np.cos(2 * np.pi * time), np.sin(2 * np.pi * time), time), axis=-1) + quat = np.zeros((100, 4)) + quat[:, 3] = 1.0 # No rotation + rpy = np.zeros((100, 3)) + cmd_rpy = np.zeros((100, 3)) + cmd_f = np.ones(100) * 0.03 + data = { + "time": np.linspace(0.0, 1.0, 100), + "pos": pos, + "quat": quat, + "rpy": rpy, + "cmd_rpy": cmd_rpy, + "cmd_f": cmd_f, + } + + data = derivatives_svf(data) + # Needed for translational sysid + data["SVF_pos"] + data["SVF_vel"] + data["SVF_acc"] + data["SVF_quat"] + data["SVF_cmd_f"] + # Needed for rotational sysid + data["SVF_rpy"] + data["SVF_cmd_rpy"] diff --git a/tests/unit/dynamics/test_rotation.py b/tests/unit/dynamics/test_rotation.py new file mode 100644 index 0000000..da46999 --- /dev/null +++ b/tests/unit/dynamics/test_rotation.py @@ -0,0 +1,223 @@ +"""Testing the selfimplemented rotations against scipy rotations.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import array_api_strict as xp +import numpy as np +import pytest +from scipy.spatial.transform import Rotation as R + +import crazyflow.dynamics.utils.rotation as rotation + +if TYPE_CHECKING: + from numpy.typing import NDArray + +tol = 1e-6 # Since Jax by default works with 32 bit, the precision is worse + + +def create_uniform_quats(N: int = 1000, scale: float = 10) -> NDArray: + """Creates an (n, 4) list with random quaternions.""" + # larger range because the function should be able to handle wrong length quaternions + return np.random.uniform(-scale, scale, size=(N, 4)) + + +def create_uniform_ang_vel(N: int = 1000, scale: float = 10) -> NDArray: + """Creates an (n, 4) list with random quaternions.""" + # larger range because the function should be able to handle wrong length quaternions + return np.random.uniform(-scale, scale, size=(N, 3)) + + +@pytest.mark.unit +def test_ang_vel2rpy_rates_two_way(): + quats = xp.asarray(create_uniform_quats()) + ang_vels = xp.asarray(create_uniform_ang_vel()) + + rpy_rates_two_way = rotation.ang_vel2rpy_rates(quats, ang_vels) + ang_vels_two_way = rotation.rpy_rates2ang_vel(quats, rpy_rates_two_way) + assert np.allclose(ang_vels, ang_vels_two_way), "Two way transform results are off." + + +@pytest.mark.unit +def test_ang_vel2rpy_rates_batching(): + quats = xp.asarray(create_uniform_quats()) + ang_vels = xp.asarray(create_uniform_ang_vel()) + + # Calculate batched version + rpy_rates_batched = rotation.ang_vel2rpy_rates(quats, ang_vels) + + # Compare to non-batched version + for i in range(ang_vels.shape[0]): + rpy_rates_non_batched = rotation.ang_vel2rpy_rates(quats[i, ...], ang_vels[i, ...]) + assert np.allclose(rpy_rates_non_batched, rpy_rates_batched[i, ...]), ( + "Batched and non-batched results differ." + ) + + +@pytest.mark.unit +def test_rpy_rates2ang_vel_batching(): + quats = xp.asarray(create_uniform_quats()) + rpy_rates = xp.asarray(create_uniform_ang_vel()) + + # Calculate batched version + ang_vel_batched = rotation.rpy_rates2ang_vel(quats, rpy_rates) + + # Compare to non-batched version + for i in range(rpy_rates.shape[0]): + ang_vel_non_batched = rotation.rpy_rates2ang_vel(quats[i, ...], rpy_rates[i, ...]) + assert np.allclose(ang_vel_non_batched, ang_vel_batched[i, ...]), ( + "Batched and non-batched results differ." + ) + + +@pytest.mark.unit +def test_ang_vel2rpy_rates_symbolic(): + quats = np.array(create_uniform_quats()) + ang_vels = np.array(create_uniform_ang_vel()) + + # Calculate numeric version + rpy_rates = rotation.ang_vel2rpy_rates(quats, ang_vels) + + # Compare to casadi implementation + for i in range(len(ang_vels)): + rpy_rates_cs = np.array(rotation.cs_ang_vel2rpy_rates(quats[i], ang_vels[i])).flatten() + assert np.allclose(rpy_rates_cs, rpy_rates[i]), "Symbolic and numeric results differ." + + +@pytest.mark.unit +def test_rpy_rates2ang_vel_symbolic(): + quats = np.array(create_uniform_quats()) + rpy_rates = np.array(create_uniform_ang_vel()) + + # Calculate numeric version + ang_vels = rotation.rpy_rates2ang_vel(quats, rpy_rates) + + # Compare to casadi implementation + for i in range(len(rpy_rates)): + ang_vel_cs = np.array(rotation.cs_rpy_rates2ang_vel(quats[i], rpy_rates[i])).flatten() + assert np.allclose(ang_vel_cs, ang_vels[i]), "Symbolic and numeric results differ." + + +@pytest.mark.unit +def test_ang_vel_deriv2rpy_rates_deriv_two_way(): + quats = xp.asarray(create_uniform_quats()) + ang_vels = xp.asarray(create_uniform_ang_vel()) + ang_vels_deriv = xp.asarray(create_uniform_ang_vel()) + rpy_rates = rotation.ang_vel2rpy_rates(quats, ang_vels) + + rpy_rates_deriv_two_way = rotation.ang_vel_deriv2rpy_rates_deriv( + quats, ang_vels, ang_vels_deriv + ) + ang_vels_deriv_two_way = rotation.rpy_rates_deriv2ang_vel_deriv( + quats, rpy_rates, rpy_rates_deriv_two_way + ) + assert np.allclose(ang_vels_deriv, ang_vels_deriv_two_way), "Two way transform results are off." + + +@pytest.mark.unit +def test_ang_vel_deriv2rpy_rates_deriv_batching(): + quats = xp.asarray(create_uniform_quats()) + ang_vels = xp.asarray(create_uniform_ang_vel()) + ang_vels_deriv = xp.asarray(create_uniform_ang_vel()) + + # Calculate batched version + rpy_rates_deriv_batched = rotation.ang_vel_deriv2rpy_rates_deriv( + quats, ang_vels, ang_vels_deriv + ) + + # Compare to non-batched version + for i in range(ang_vels.shape[0]): + rpy_rates_deriv_non_batched = rotation.ang_vel_deriv2rpy_rates_deriv( + quats[i, ...], ang_vels[i, ...], ang_vels_deriv[i, ...] + ) + assert np.allclose(rpy_rates_deriv_non_batched, rpy_rates_deriv_batched[i, ...]), ( + "Batched and non-batched results differ." + ) + + +@pytest.mark.unit +def test_rpy_rates_deriv2ang_vel_deriv_batching(): + quats = np.array(create_uniform_quats()) + rpy_rates = np.array(create_uniform_ang_vel()) + rpy_rates_deriv = np.array(create_uniform_ang_vel()) + + # Calculate batched version + ang_vels_deriv_batched = rotation.rpy_rates_deriv2ang_vel_deriv( + quats, rpy_rates, rpy_rates_deriv + ) + + # Compare to non-batched version + for i in range(rpy_rates.shape[0]): + ang_vels_deriv_non_batched = rotation.rpy_rates_deriv2ang_vel_deriv( + quats[i, ...], rpy_rates[i, ...], rpy_rates_deriv[i, ...] + ) + assert np.allclose(ang_vels_deriv_non_batched, ang_vels_deriv_batched[i, ...]), ( + "Batched and non-batched results differ." + ) + + +@pytest.mark.unit +def test_ang_vel_deriv2rpy_rates_deriv_symbolic(): + quats = np.array(create_uniform_quats()) + ang_vels = np.array(create_uniform_ang_vel()) + ang_vels_deriv = np.array(create_uniform_ang_vel()) + + # Calculate batched version + rpy_rates_deriv = rotation.ang_vel_deriv2rpy_rates_deriv(quats, ang_vels, ang_vels_deriv) + + # Compare to casadi implementation + for i in range(ang_vels.shape[0]): + # TODO test against casadi implementation + rpy_rates_deriv_cs = np.array( + rotation.cs_ang_vel_deriv2rpy_rates_deriv(quats[i], ang_vels[i], ang_vels_deriv[i]) + ).flatten() + assert np.allclose(rpy_rates_deriv_cs, rpy_rates_deriv[i]), ( + "Symbolic and numeric results differ." + ) + + +@pytest.mark.unit +def test_rpy_rates_deriv2ang_vel_deriv_symbolic(): + quats = np.array(create_uniform_quats()) + rpy_rates = np.array(create_uniform_ang_vel()) + rpy_rates_deriv = np.array(create_uniform_ang_vel()) + + # Calculate batched version + ang_vels_deriv = rotation.rpy_rates_deriv2ang_vel_deriv(quats, rpy_rates, rpy_rates_deriv) + + # Compare to casadi implementation + for i in range(len(rpy_rates)): + # TODO test against casadi implementation + ang_vels_deriv_cs = np.array( + rotation.cs_rpy_rates_deriv2ang_vel_deriv(quats[i], rpy_rates[i], rpy_rates_deriv[i]) + ).flatten() + assert np.allclose(ang_vels_deriv_cs, ang_vels_deriv[i]), ( + "Symbolic and numeric results differ." + ) + + +# TODO test ang_vel2rpy_rates (and deriv) conversions with jp and np arrays + + +@pytest.mark.unit +def test_quat2matrix_symbolic(): + quats = np.array(create_uniform_quats()) + for i, q in enumerate(quats): + mat_scipy = R.from_quat(q).as_matrix() + + # compare casadi/symbolic implementation to scipy + mat_cs = np.array(rotation.cs_quat2matrix_func(q)) + assert np.allclose(mat_cs, mat_scipy, atol=tol), "Symbolic quat->matrix differs from scipy." + + +@pytest.mark.unit +def test_rpy2matrix_symbolic(): + rpys = np.array(create_uniform_ang_vel()) + for i, rpy in enumerate(rpys): + rpy = rpys[i] + mat_scipy = R.from_euler("xyz", rpy).as_matrix() + + # compare casadi/symbolic implementation to scipy + mat_cs = np.array(rotation.cs_rpy2matrix_func(rpy)) + assert np.allclose(mat_cs, mat_scipy, atol=tol), "Symbolic rpy->matrix differs from scipy." diff --git a/tests/unit/test_gradients.py b/tests/unit/test_gradients.py index 8f2318b..4b73adf 100644 --- a/tests/unit/test_gradients.py +++ b/tests/unit/test_gradients.py @@ -5,16 +5,16 @@ import pytest from jax import Array +from crazyflow.dynamics import Dynamics from crazyflow.sim import Sim from crazyflow.sim.data import Control, SimData -from crazyflow.sim.physics import Physics @pytest.mark.skip(reason="State needs SVD in from_matrix, which is not differentiable.") @pytest.mark.unit -@pytest.mark.parametrize("physics", Physics) -def test_state_cmd_gradients(physics: Physics): - sim = Sim(physics=physics, control=Control.state, freq=500) +@pytest.mark.parametrize("dynamics", Dynamics) +def test_state_cmd_gradients(dynamics: Dynamics): + sim = Sim(dynamics=dynamics, control=Control.state, freq=500) sim_step = sim._step def step(cmd: Array, data: SimData) -> Array: @@ -34,9 +34,9 @@ def step(cmd: Array, data: SimData) -> Array: @pytest.mark.unit -@pytest.mark.parametrize("physics", Physics) -def test_attitude_cmd_gradients(physics: Physics): - sim = Sim(physics=physics, control=Control.attitude, freq=500) +@pytest.mark.parametrize("dynamics", Dynamics) +def test_attitude_cmd_gradients(dynamics: Dynamics): + sim = Sim(dynamics=dynamics, control=Control.attitude, freq=500) def step(cmd: Array, data: SimData) -> Array: data = data.replace( @@ -56,7 +56,7 @@ def step(cmd: Array, data: SimData) -> Array: @pytest.mark.unit def test_force_torque_cmd_gradients(): - sim = Sim(physics=Physics.first_principles, control=Control.force_torque, freq=500) + sim = Sim(dynamics=Dynamics.first_principles, control=Control.force_torque, freq=500) def step(cmd: Array, data: SimData) -> Array: data = data.replace( diff --git a/tests/unit/test_render.py b/tests/unit/test_render.py index b1b6c23..7990c7c 100644 --- a/tests/unit/test_render.py +++ b/tests/unit/test_render.py @@ -10,7 +10,7 @@ @pytest.mark.render @skip_if_headless def test_render_camera_selection_from_name(cam_name: str): - sim = Sim(drone_model="cf21B_500", n_drones=2) + sim = Sim(drone="cf21B_500", n_drones=2) cam_id = mujoco.mj_name2id(sim.mj_model, mujoco.mjtObj.mjOBJ_CAMERA, cam_name) sim.render(mode="human", camera=cam_name) viewer_cam = sim.viewer.viewer.cam @@ -24,7 +24,7 @@ def test_render_camera_selection_from_name(cam_name: str): @pytest.mark.render @skip_if_headless def test_render_camera_selection_from_id(cam_id: int): - sim = Sim(drone_model="cf21B_500", n_drones=2) + sim = Sim(drone="cf21B_500", n_drones=2) sim.render(mode="human", camera=cam_id) viewer_cam = sim.viewer.viewer.cam assert viewer_cam.type == mujoco.mjtCamera.mjCAMERA_FIXED, "Camera type was not set to FIXED" @@ -36,7 +36,7 @@ def test_render_camera_selection_from_id(cam_id: int): @pytest.mark.render @skip_if_headless def test_render_free_camera(): - sim = Sim(drone_model="cf21B_500", n_drones=2) + sim = Sim(drone="cf21B_500", n_drones=2) sim.render(mode="human") viewer_cam = sim.viewer.viewer.cam assert viewer_cam.type == mujoco.mjtCamera.mjCAMERA_FREE, "Camera type was not set to FREE" diff --git a/tests/unit/test_sim.py b/tests/unit/test_sim.py index a548ddc..315fe05 100644 --- a/tests/unit/test_sim.py +++ b/tests/unit/test_sim.py @@ -14,7 +14,7 @@ from crazyflow.control import Control from crazyflow.exception import ConfigError -from crazyflow.sim import Physics, Sim +from crazyflow.sim import Dynamics, Sim from crazyflow.sim.data import ControlData, SimData from crazyflow.sim.sim import sync_sim2mjx from crazyflow.sim.visualize import change_material @@ -50,23 +50,23 @@ def array_compare_assert(x: Array, y: Array, value: bool = True, name: str | Non @pytest.mark.unit -@pytest.mark.parametrize("physics", Physics) +@pytest.mark.parametrize("dynamics", Dynamics) @pytest.mark.parametrize("control", Control) @pytest.mark.parametrize("n_worlds", [1, 2]) -def test_sim_init(physics: Physics, device: str, control: Control, n_worlds: int): +def test_sim_init(dynamics: Dynamics, device: str, control: Control, n_worlds: int): n_drones = 1 - if physics != Physics.first_principles: + if dynamics != Dynamics.first_principles: if control in (Control.force_torque, Control.rotor_vel): with pytest.raises(ConfigError): - Sim(n_worlds=n_worlds, physics=physics, device=device, control=control) + Sim(n_worlds=n_worlds, dynamics=dynamics, device=device, control=control) return - sim = Sim(n_worlds=n_worlds, physics=physics, device=device, control=control) + sim = Sim(n_worlds=n_worlds, dynamics=dynamics, device=device, control=control) assert sim.n_worlds == n_worlds assert sim.n_drones == n_drones assert sim.device == jax.devices(device)[0] - assert sim.physics == physics + assert sim.dynamics == dynamics # Test state buffer shapes array_meta_assert(sim.data.states.pos, (n_worlds, n_drones, 3), device, "pos") @@ -98,12 +98,12 @@ def test_sim_init(physics: Physics, device: str, control: Control, n_worlds: int @pytest.mark.unit -@pytest.mark.parametrize("physics", Physics) +@pytest.mark.parametrize("dynamics", Dynamics) @pytest.mark.parametrize("n_worlds", [1, 2]) @pytest.mark.parametrize("n_drones", [1, 3]) -def test_reset(device: str, physics: Physics, n_worlds: int, n_drones: int): +def test_reset(device: str, dynamics: Dynamics, n_worlds: int, n_drones: int): """Test that reset without mask resets all worlds to default state.""" - sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=physics, device=device) + sim = Sim(n_worlds=n_worlds, n_drones=n_drones, dynamics=dynamics, device=device) # Modify states data = sim.data @@ -136,10 +136,10 @@ def test_reset(device: str, physics: Physics, n_worlds: int, n_drones: int): @pytest.mark.unit -@pytest.mark.parametrize("physics", Physics) -def test_reset_masked(device: str, physics: Physics): +@pytest.mark.parametrize("dynamics", Dynamics) +def test_reset_masked(device: str, dynamics: Dynamics): """Test that reset with mask only resets specified worlds.""" - sim = Sim(n_worlds=2, n_drones=1, physics=physics, device=device) + sim = Sim(n_worlds=2, n_drones=1, dynamics=dynamics, device=device) # Modify states data = sim.data @@ -182,14 +182,16 @@ def test_reset_masked(device: str, physics: Physics): @pytest.mark.unit @pytest.mark.parametrize("n_worlds", [1, 2]) @pytest.mark.parametrize("n_drones", [1, 3]) -@pytest.mark.parametrize("physics", Physics) +@pytest.mark.parametrize("dynamics", Dynamics) @pytest.mark.parametrize("control", Control) -def test_sim_step(n_worlds: int, n_drones: int, physics: Physics, control: Control, device: str): - if physics != Physics.first_principles: +def test_sim_step(n_worlds: int, n_drones: int, dynamics: Dynamics, control: Control, device: str): + if dynamics != Dynamics.first_principles: if control in (Control.force_torque, Control.rotor_vel): - pytest.skip(f"{control} is not supported with non-first-principles physics") + pytest.skip(f"{control} is not supported with non-first-principles dynamics") - sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=physics, device=device, control=control) + sim = Sim( + n_worlds=n_worlds, n_drones=n_drones, dynamics=dynamics, device=device, control=control + ) sim.step(2) @@ -290,7 +292,7 @@ def test_render_rgb_array(device: str): @pytest.mark.unit def test_device(device: str): - sim = Sim(n_worlds=2, physics=Physics.so_rpy, device=device) + sim = Sim(n_worlds=2, dynamics=Dynamics.so_rpy, device=device) sim.step() assert sim.data.states.pos.device == jax.devices(device)[0] @@ -299,7 +301,7 @@ def test_device(device: str): @pytest.mark.parametrize("n_worlds", [1, 2]) @pytest.mark.parametrize("n_drones", [1, 3]) def test_sync_shape_consistency(device: str, n_drones: int, n_worlds: int): - sim = Sim(n_worlds=n_worlds, n_drones=n_drones, physics=Physics.so_rpy, device=device) + sim = Sim(n_worlds=n_worlds, n_drones=n_drones, dynamics=Dynamics.so_rpy, device=device) qpos_shape, qvel_shape = sim.mjx_data.qpos.shape, sim.mjx_data.qvel.shape _, mjx_data = sync_sim2mjx(sim.data, sim.mjx_data, sim.mjx_model) assert mjx_data.qpos.shape == qpos_shape, "sync_sim2mjx() should not change qpos shape" @@ -307,11 +309,11 @@ def test_sync_shape_consistency(device: str, n_drones: int, n_worlds: int): @pytest.mark.unit -@pytest.mark.parametrize("physics", Physics) -def test_control_frequency(physics: Physics): +@pytest.mark.parametrize("dynamics", Dynamics) +def test_control_frequency(dynamics: Dynamics): # Create two sims with different frequencies - sim_500 = Sim(freq=500, physics=physics, control="state") - sim_1000 = Sim(freq=1000, physics=physics, control="state") + sim_500 = Sim(freq=500, dynamics=dynamics, control="state") + sim_1000 = Sim(freq=1000, dynamics=dynamics, control="state") # Set same initial state and controls cmd = np.zeros((1, 1, 13)) # Single world, single drone, state control @@ -369,14 +371,14 @@ def test_seed_reset(): @pytest.mark.unit -@pytest.mark.parametrize("physics", Physics) -def test_floor_penetration(physics: Physics): +@pytest.mark.parametrize("dynamics", Dynamics) +def test_floor_penetration(dynamics: Dynamics): """Test that drones cannot penetrate the floor (z < 0.01). We don't test for mujoco, as mujoco uses collisions by default and will let the drone bounce on the floor. """ - sim = Sim(physics=physics, control=Control.attitude, freq=500, device="cpu") + sim = Sim(dynamics=dynamics, control=Control.attitude, freq=500, device="cpu") sim.reset() # Command to fall: zero thrust and attitude that points downward attitude_cmd = np.zeros((1, 1, 4)) # [roll, pitch, yaw, thrust] @@ -395,9 +397,9 @@ def test_floor_penetration(physics: Physics): @pytest.mark.unit -@pytest.mark.parametrize("physics", Physics) -def test_contacts(physics: Physics): - sim = Sim(physics=physics, control=Control.attitude, freq=500, device="cpu") +@pytest.mark.parametrize("dynamics", Dynamics) +def test_contacts(dynamics: Dynamics): + sim = Sim(dynamics=dynamics, control=Control.attitude, freq=500, device="cpu") sim.reset() sim.step(10) # Make sure the drone is on the ground contacts = sim.contacts() @@ -409,7 +411,7 @@ def test_contacts(physics: Physics): @pytest.mark.parametrize("control", Control) def test_data_committed(control: Control, device: str): # Check that the data is committed to the device we chose - sim = Sim(physics=Physics.first_principles, control=control, freq=500, device=device) + sim = Sim(dynamics=Dynamics.first_principles, control=control, freq=500, device=device) def assert_committed(obj0: Array | Any, path: str = "data"): if isinstance(obj0, jnp.ndarray): @@ -434,9 +436,9 @@ def assert_committed(obj0: Array | Any, path: str = "data"): @pytest.mark.unit -@pytest.mark.parametrize("physics", Physics) -def test_compile(physics: Physics, device: str): - sim = Sim(physics=physics, control=Control.state, freq=500, device=device) +@pytest.mark.parametrize("dynamics", Dynamics) +def test_compile(dynamics: Dynamics, device: str): + sim = Sim(dynamics=dynamics, control=Control.state, freq=500, device=device) # Make sure we don't recompile the step function after the first call sim.step(1) sim.step(1) @@ -445,9 +447,9 @@ def test_compile(physics: Physics, device: str): @pytest.mark.unit -@pytest.mark.parametrize("physics", Physics) -def test_scan_results(physics: Physics): - sim = Sim(n_worlds=2, n_drones=3, physics=physics, control=Control.state, device="cpu") +@pytest.mark.parametrize("dynamics", Dynamics) +def test_scan_results(dynamics: Dynamics): + sim = Sim(n_worlds=2, n_drones=3, dynamics=dynamics, control=Control.state, device="cpu") sim.reset() cmd = np.zeros((sim.n_worlds, sim.n_drones, 13)) cmd[..., :3] = sim.data.states.pos + np.array([0.3, 0.3, 0.3]) @@ -466,13 +468,13 @@ def test_scan_results(physics: Physics): @pytest.mark.unit -@pytest.mark.parametrize("drone_model", ["cf2x_L250", "cf2x_P250", "cf2x_T350", "cf21B_500"]) +@pytest.mark.parametrize("drone", ["cf2x_L250", "cf2x_P250", "cf2x_T350", "cf21B_500"]) @pytest.mark.parametrize("mat_name", ["led_top", "led_bot"]) -def test_change_material(device: str, drone_model: str, mat_name: str): +def test_change_material(device: str, drone: str, mat_name: str): """change_material should broadcast RGBA/emission and update MuJoCo materials appropriately.""" n_drones = 2 - sim = Sim(n_drones=n_drones, drone_model=drone_model, device=device) + sim = Sim(n_drones=n_drones, drone=drone, device=device) drone_ids = np.array([0, 1], dtype=int) rgba = 0.42 * np.ones((n_drones, 4), dtype=float) @@ -526,9 +528,9 @@ def test_build_data(control: Control): @pytest.mark.unit -@pytest.mark.parametrize("drone_model", ["cf2x_L250", "cf2x_P250", "cf2x_T350", "cf21B_500"]) -def test_fused_model(device: str, drone_model: str): - sim = Sim(drone_model=drone_model, fused_mjx_model=True, device=device) +@pytest.mark.parametrize("drone", ["cf2x_L250", "cf2x_P250", "cf2x_T350", "cf21B_500"]) +def test_fused_model(device: str, drone: str): + sim = Sim(drone=drone, fused_mjx_model=True, device=device) sim.reset() sim.step(1) sim.close() diff --git a/tests/unit/test_symbolic.py b/tests/unit/test_symbolic.py deleted file mode 100644 index 0351299..0000000 --- a/tests/unit/test_symbolic.py +++ /dev/null @@ -1,32 +0,0 @@ -import casadi as cs -import pytest - -from crazyflow.control import Control -from crazyflow.sim import Sim -from crazyflow.sim.symbolic import symbolic_from_sim - - -@pytest.mark.unit -@pytest.mark.parametrize("n_worlds", [1, 2]) -def test_symbolic_from_sim(n_worlds: int): - """Test creating symbolic model from sim instance.""" - sim = Sim(n_worlds=n_worlds, n_drones=1, control=Control.attitude) - X_dot, X, U, Y = symbolic_from_sim(sim) - - assert isinstance(X_dot, cs.MX) - assert isinstance(X, cs.MX) - assert isinstance(U, cs.MX) - assert isinstance(Y, cs.MX) - assert X_dot.shape == (13, 1) - assert X.shape == (13, 1) - assert U.shape == (4, 1) - assert Y.shape == (7, 1) - - -@pytest.mark.unit -@pytest.mark.parametrize("control", [Control.state, Control.force_torque]) -def test_symbolic_from_sim_errors(control: Control): - """Test creating symbolic model from sim instance.""" - sim = Sim(control=control) - with pytest.raises(ValueError, match="Symbolic model dynamics only support"): - symbolic_from_sim(sim)