
Public API

This guide covers the supported public API for users building trajectory optimization problems with TBM.

Installation

Install the core solver with the default Clarabel-based backend:

pip install .

Optional extras:

pip install ".[cvxpy]"
pip install ".[mjwarp]"
pip install ".[examples]"
  • cvxpy: enables TBMConfig(subproblem_backend="cvxpy").
  • mjwarp: enables MjWarpBackend for MuJoCo/MJWarp rollouts.
  • examples: installs plotting dependencies used by example scripts.

Main Objects

The package exports these public entry points:

  • TBMProblem: problem definition, including rollout, costs, constraints, and bounds.
  • InitialGuess: initial controls plus optional guessed states.
  • TBMConfig: solver settings, backend selection, trust-region settings, and tolerances.
  • solve(problem, initial_guess, config=None, backend=None): solve a problem and return TBMResult.
  • TBMResult: solver status, iteration history, optimized states, and optimized controls.
  • MjWarpBackend: optional rollout backend available when the mjwarp extra is installed.

Shape Conventions

TBM expects batched callbacks.

Rollout callback:

def rollout(initial_states, control_segments):
    # initial_states.shape == (batch, state_dim)
    # control_segments.shape == (batch, steps, control_dim)
    # return shape == (batch, steps + 1, state_dim)
    ...
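
Purely as an illustration of the batched rollout contract, here is a sketch of a discrete-time double integrator (state = [position, velocity], control = [acceleration]) that propagates every batch element at once. The dynamics, the time step DT, and the name double_integrator_rollout are assumptions for this example, not part of TBM.

```python
import numpy as np

DT = 0.1  # hypothetical time step for this example


def double_integrator_rollout(initial_states, control_segments):
    """Batched rollout: (batch, 2), (batch, steps, 1) -> (batch, steps + 1, 2)."""
    batch, steps, _ = control_segments.shape
    states = np.empty((batch, steps + 1, initial_states.shape[1]))
    states[:, 0, :] = initial_states
    for step in range(steps):
        pos = states[:, step, 0]
        vel = states[:, step, 1]
        acc = control_segments[:, step, 0]
        # Semi-implicit Euler update, vectorized over the batch axis.
        new_vel = vel + DT * acc
        states[:, step + 1, 0] = pos + DT * new_vel
        states[:, step + 1, 1] = new_vel
    return states


x0 = np.zeros((3, 2))   # batch of 3 initial states
u = np.ones((3, 4, 1))  # 4 steps of unit acceleration
print(double_integrator_rollout(x0, u).shape)  # (3, 5, 2)
```

Only the time loop is explicit; each update acts on the whole batch axis, which is what makes a batched rollout cheap to call with many segments.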

Stage callbacks receive one timestep index plus batched states and controls:

def stage_cost_residual(step, states, controls):
    # states.shape == (batch, state_dim)
    # controls.shape == (batch, control_dim)
    # return shape == (batch, residual_dim)
    ...
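
A stage residual typically stacks several terms along the residual axis. The sketch below assumes the [position, velocity] state layout from a double integrator and uses hypothetical reference values and weights; the step argument selects the position reference for the current timestep.

```python
import numpy as np

# Hypothetical per-step position references and weights for this example.
REFERENCE = np.linspace(0.0, 1.0, 6)
W_POS, W_CTRL = 1.0, 0.1


def stage_cost_residual(step, states, controls):
    """(batch, 2), (batch, 1) -> (batch, 2): position error and control effort."""
    pos_error = W_POS * (states[:, 0:1] - REFERENCE[step])
    effort = W_CTRL * controls
    # Concatenate along the residual axis; the batch axis is preserved.
    return np.concatenate([pos_error, effort], axis=1)


print(stage_cost_residual(2, np.array([[0.5, 0.0]]), np.array([[1.0]])).shape)  # (1, 2)
```

Because the solver squares residual norms, a weight w applied to a residual contributes w**2 to the cost.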

Terminal callbacks receive batched terminal states:

def terminal_cost_residual(states):
    # states.shape == (batch, state_dim)
    # return shape == (batch, residual_dim)
    ...
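
A terminal residual is often a weighted distance to a goal state. In this sketch the goal and weights are hypothetical; both are 1-D arrays that broadcast across the batch axis.

```python
import numpy as np

# Hypothetical goal state [position, velocity] and per-component weights.
GOAL = np.array([1.0, 0.0])
WEIGHTS = np.array([10.0, 1.0])


def terminal_cost_residual(states):
    """(batch, 2) -> (batch, 2): weighted distance to GOAL."""
    # GOAL and WEIGHTS have shape (2,) and broadcast over the batch axis.
    return WEIGHTS * (states - GOAL)


print(terminal_cost_residual(np.array([[0.0, 1.0]])).shape)  # (1, 2)
```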

For convenience, TBM also accepts single-sample callback outputs when batch == 1.
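
Before calling solve(), it can help to check your callbacks against these conventions on a random batch. check_callback_shapes is a hypothetical helper, not part of the tbm package; it exercises only the documented batched shapes (with batch > 1, so it does not rely on the batch == 1 convenience).

```python
import numpy as np


def check_callback_shapes(rollout, stage_residual, terminal_residual,
                          state_dim, control_dim, batch=4, steps=3, seed=0):
    """Raise AssertionError if a callback violates the batched shape conventions."""
    rng = np.random.default_rng(seed)
    x0 = rng.standard_normal((batch, state_dim))
    u = rng.standard_normal((batch, steps, control_dim))

    traj = np.asarray(rollout(x0, u))
    assert traj.shape == (batch, steps + 1, state_dim), traj.shape

    stage = np.asarray(stage_residual(0, traj[:, 0, :], u[:, 0, :]))
    assert stage.ndim == 2 and stage.shape[0] == batch, stage.shape

    term = np.asarray(terminal_residual(traj[:, -1, :]))
    assert term.ndim == 2 and term.shape[0] == batch, term.shape
    return stage.shape[1], term.shape[1]  # (stage residual_dim, terminal residual_dim)


# Example: scalar integrator x[k+1] = x[k] + u[k], written with cumsum.
def rollout(x0, u):
    start = x0[:, None, :]
    return np.concatenate([start, start + np.cumsum(u, axis=1)], axis=1)


print(check_callback_shapes(rollout, lambda k, x, u: u, lambda x: x - 1.0, 1, 1))  # (1, 1)
```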

Defining a Problem

TBMProblem requires:

  • state_dim
  • control_dim
  • horizon
  • initial_state
  • rollout, or a backend passed separately to solve()

Optional pieces:

  • segment_length: partial-shooting segment length. The default, 1, means full shooting: every timestep starts its own segment.
  • stage_cost_residual and terminal_cost_residual: least-squares objective terms.
  • stage_equality, terminal_equality, stage_inequality, terminal_inequality: constraints.
  • state_lower, state_upper, control_lower, control_upper: box constraints.
  • terminal_state_lower, terminal_state_upper: terminal-state bounds.
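
To make segment_length concrete, here is a generic multiple-shooting sketch, not TBM's internal code: the controls are split into segments, every segment is rolled out from its own start state in one batched rollout call, and the gap ("defect") between each segment's end state and the next segment's start state is what the dynamics slacks must close. The helper name shooting_defects is hypothetical, and the sketch assumes equal-length segments, so horizon - 1 must be divisible by segment_length.

```python
import numpy as np


def shooting_defects(rollout, segment_starts, controls, segment_length):
    """Continuity gaps for consecutive, equal-length shooting segments.

    segment_starts: (num_segments, state_dim) start state of each segment.
    controls:       (horizon - 1, control_dim), split into num_segments chunks.
    Returns (num_segments - 1, state_dim).
    """
    num_steps, control_dim = controls.shape
    assert num_steps % segment_length == 0, "sketch assumes equal-length segments"
    num_segments = num_steps // segment_length
    # One batched rollout call: batch = segments, steps = segment_length.
    chunks = controls.reshape(num_segments, segment_length, control_dim)
    trajs = rollout(segment_starts, chunks)
    # The end of segment i should match the start of segment i + 1.
    return trajs[:-1, -1, :] - segment_starts[1:]


def integrator_rollout(x0, u):
    start = x0[:, None, :]
    return np.concatenate([start, start + np.cumsum(u, axis=1)], axis=1)


controls = np.full((6, 1), 0.5)           # horizon = 7, so 6 control steps
starts = np.array([[0.0], [1.0], [2.0]])  # consistent with the controls
print(shooting_defects(integrator_rollout, starts, controls, segment_length=2).ravel())  # [0. 0.]
```

With segment_length=1 every state is its own segment start, which corresponds to the full-shooting default described above.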

The solver minimizes squared residual norms plus slack penalties for dynamics and constraint violations.
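
The least-squares part of that objective can be written out directly. This sketch assumes the stage residual is evaluated at each of the horizon - 1 steps that have a control, the terminal residual at the last state, and no 1/2 scaling; TBM may index or scale terms differently, and the slack penalties are omitted. least_squares_objective is a hypothetical helper.

```python
import numpy as np


def least_squares_objective(states, controls, stage_residual, terminal_residual):
    """Sum of squared residual norms over one trajectory (slack penalties omitted).

    states:   (horizon, state_dim)
    controls: (horizon - 1, control_dim)
    """
    total = 0.0
    for step in range(controls.shape[0]):
        # Callbacks are batched, so add a batch axis of size 1.
        r = stage_residual(step, states[step][None, :], controls[step][None, :])
        total += float(np.sum(np.square(r)))
    r_terminal = terminal_residual(states[-1][None, :])
    total += float(np.sum(np.square(r_terminal)))
    return total


states = np.array([[0.0], [0.5], [1.0]])
controls = np.array([[0.5], [0.5]])
print(least_squares_objective(states, controls, lambda k, x, u: u, lambda x: x - 1.0))  # 0.5
```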

Initial Guess

InitialGuess always requires controls with shape (horizon - 1, control_dim). Optionally, you can also provide states with shape (horizon, state_dim) to warm-start the segment start states.
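
One simple way to fill the optional states is a straight-line interpolation from initial_state to a goal, paired with zero controls. straight_line_warm_start is a hypothetical helper that only builds arrays of the documented shapes; the goal state is an assumption you supply.

```python
import numpy as np


def straight_line_warm_start(initial_state, goal_state, horizon, control_dim):
    """Return (controls, states) shaped (horizon - 1, control_dim) and (horizon, state_dim)."""
    initial_state = np.asarray(initial_state, dtype=float)
    goal_state = np.asarray(goal_state, dtype=float)
    # Interpolation weights 0, 1/(horizon - 1), ..., 1, one per timestep.
    alphas = np.linspace(0.0, 1.0, horizon)[:, None]
    states = (1.0 - alphas) * initial_state + alphas * goal_state
    controls = np.zeros((horizon - 1, control_dim))
    return controls, states


controls, states = straight_line_warm_start([0.0], [1.0], horizon=6, control_dim=1)
print(controls.shape, states.shape)  # (5, 1) (6, 1)
```

The arrays would then be passed as InitialGuess(controls=controls, states=states); the states= keyword is assumed here from the field description above.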

Solver Configuration

Common TBMConfig fields:

  • max_iterations: outer TBM iterations.
  • state_trust_region, control_trust_region: trust-region sizes, either scalars or dense vectors.
  • step_tolerance, cost_tolerance, feasibility_tolerance: stopping thresholds.
  • subproblem_backend: "clarabel_direct" by default or "cvxpy" with the cvxpy extra.
  • verbose: prints outer-iteration diagnostics.
  • cvxpy_verbose: prints inner subproblem-solver logs; despite the name, it applies to either backend.

Backend guidance:

  • clarabel_direct: default, faster startup, no cvxpy dependency.
  • cvxpy: useful for backend comparisons or debugging canonicalized solver calls.
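
If one script must run with or without the cvxpy extra, you can pick subproblem_backend at runtime. This sketch detects cvxpy with importlib.util.find_spec from the standard library; the helper name choose_subproblem_backend is hypothetical.

```python
import importlib.util


def choose_subproblem_backend(prefer_cvxpy=False):
    """Return "cvxpy" only if requested and importable, else "clarabel_direct"."""
    if prefer_cvxpy and importlib.util.find_spec("cvxpy") is not None:
        return "cvxpy"
    return "clarabel_direct"


print(choose_subproblem_backend())  # clarabel_direct
```

The result would be passed as TBMConfig(subproblem_backend=choose_subproblem_backend(prefer_cvxpy=True)) when comparing backends; falling back to the default keeps the script working when the extra is not installed.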

Result Object

solve() returns TBMResult with:

  • status: final solver status string.
  • converged: whether the configured stopping criteria were met.
  • iterations: number of outer iterations executed.
  • objective: nonlinear objective at the returned iterate.
  • max_constraint_violation: maximum evaluated constraint violation.
  • checkpoint_states: segment boundary states.
  • states: full rolled-out state trajectory with shape (horizon, state_dim).
  • controls: optimized controls with shape (horizon - 1, control_dim).
  • history: per-iteration timing and diagnostics.
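
A small reporting helper can be built from the documented fields alone. summarize_result is hypothetical, and the stand-in object below only mimics the attribute names of TBMResult (its values, including the status string, are made up) so the snippet runs without solving anything.

```python
from types import SimpleNamespace

import numpy as np


def summarize_result(result):
    """One-line report built from documented TBMResult fields."""
    final_state = np.asarray(result.states)[-1]
    return (
        f"status={result.status} converged={result.converged} "
        f"iterations={result.iterations} objective={result.objective:.3e} "
        f"max_violation={result.max_constraint_violation:.1e} "
        f"final_state={np.array2string(final_state, precision=3)}"
    )


# Stand-in with the same attribute names, for demonstration only.
fake = SimpleNamespace(
    status="example-status", converged=True, iterations=4, objective=1.2e-9,
    max_constraint_violation=0.0, states=np.array([[0.0], [1.0]]),
    controls=np.array([[1.0]]),
)
print(summarize_result(fake))
```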

Minimal Example

import numpy as np

from tbm import InitialGuess, TBMConfig, TBMProblem, solve


def rollout(initial_states, control_segments):
    # Scalar integrator x[k+1] = x[k] + u[k], applied to every batch element.
    batch, steps, _ = control_segments.shape
    states = np.zeros((batch, steps + 1, 1))
    states[:, 0, :] = initial_states
    for step in range(steps):
        states[:, step + 1, 0] = states[:, step, 0] + control_segments[:, step, 0]
    return states


problem = TBMProblem(
    state_dim=1,
    control_dim=1,
    horizon=6,
    initial_state=np.array([0.0]),
    rollout=rollout,
    terminal_cost_residual=lambda x: x - np.array([[1.0]]),
)

guess = InitialGuess(controls=np.zeros((5, 1)))
result = solve(problem, guess, TBMConfig(max_iterations=20))
print(result.status, result.states[-1, 0])
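
In this example the only cost is the terminal residual, so any controls summing to 1.0 are optimal, and the printed final state should be close to 1.0 if the solve converges. Adding a control-effort term such as stage_cost_residual=lambda step, x, u: u makes the optimum unique, and because the dynamics are linear you can compute that optimum directly as a linear least-squares problem to cross-check the solver. The sketch assumes the stage residual is applied once to each of the five controls.

```python
import numpy as np

steps = 5  # horizon - 1 controls in the minimal example
# Rows 0..4: control-effort residuals u_k.
# Row 5: terminal residual sum(u) - 1, since x_final = x_0 + sum(u) with x_0 = 0.
A = np.vstack([np.eye(steps), np.ones((1, steps))])
b = np.concatenate([np.zeros(steps), [1.0]])
u_opt, *_ = np.linalg.lstsq(A, b, rcond=None)
print(u_opt)        # each entry ≈ 1/6 ≈ 0.1667
print(u_opt.sum())  # final state ≈ 5/6 ≈ 0.8333
```

The normal equations here are (I + 1 1^T) u = 1, whose solution is u_k = 1/6 for every k, so a correctly configured solve of the modified problem should return result.states[-1, 0] near 5/6.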