Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/smpclient/transport/serial/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Serial SMPTransports.

In addition to UART, these transports can be used with USB CDC ACM and CAN.
"""

from smpclient.transport.serial.encoded import Auto as Auto
from smpclient.transport.serial.encoded import BufferParams as BufferParams
from smpclient.transport.serial.encoded import BufferSize as BufferSize
from smpclient.transport.serial.encoded import FragmentationStrategy as FragmentationStrategy
from smpclient.transport.serial.encoded import SMPSerialTransport as SMPSerialTransport
from smpclient.transport.serial.unencoded import SMPSerialRawTransport as SMPSerialRawTransport
156 changes: 156 additions & 0 deletions src/smpclient/transport/serial/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""Shared connection management for the encoded and unencoded serial transports."""

import asyncio
import logging
from contextlib import contextmanager
from time import monotonic
from typing import Final, Generator, final

try:
from serial import Serial, SerialException
except ModuleNotFoundError as e:
if e.name == "serial":
raise ImportError(
"Serial transport requires the 'serial' extra. Use smpclient[serial]"
) from e
raise
from typing_extensions import override

from smpclient.transport import SMPTransport, SMPTransportDisconnected

logger = logging.getLogger(__name__)


class _SerialTransportBase(SMPTransport):
"""Connection-management base class for serial-port-backed SMP transports.

Holds the `pyserial` `Serial` instance, the open/retry connect loop, disconnect,
and the small TX/RX helpers that wrap `SerialException` into
`SMPTransportDisconnected`.

Subclasses implement `send` and `receive` with their framing of choice, may
override `_reset_state` to clear per-connection state on `connect`, and may
override `connect` to back the transport with a byte pipe other than a local
serial port (e.g. an emulator's `socket://` chardev).
"""

_POLLING_INTERVAL_S: Final = 0.005
_CONNECTION_RETRY_INTERVAL_S: Final = 0.500

def __init__(
self,
baudrate: int = 115200,
bytesize: int = 8,
parity: str = "N",
stopbits: float = 1,
timeout: float | None = None,
xonxoff: bool = False,
rtscts: bool = False,
write_timeout: float | None = None,
dsrdtr: bool = False,
inter_byte_timeout: float | None = None,
exclusive: bool | None = None,
) -> None:
"""Initialize the underlying `pyserial` `Serial` instance.

Args:
baudrate: The baudrate of the serial connection. OK to ignore for
USB CDC ACM.
bytesize: The number of data bits.
parity: The parity setting.
stopbits: The number of stop bits.
timeout: The read timeout.
xonxoff: Enable software flow control.
rtscts: Enable hardware (RTS/CTS) flow control.
write_timeout: The write timeout.
dsrdtr: Enable hardware (DSR/DTR) flow control.
inter_byte_timeout: The inter-byte timeout.
exclusive: Set exclusive access mode (POSIX only). A port cannot be
opened in exclusive access mode if it is already open in
exclusive access mode.
"""
self._conn: Final = Serial(
baudrate=baudrate,
bytesize=bytesize,
parity=parity,
stopbits=stopbits,
timeout=timeout,
xonxoff=xonxoff,
rtscts=rtscts,
write_timeout=write_timeout,
dsrdtr=dsrdtr,
inter_byte_timeout=inter_byte_timeout,
exclusive=exclusive,
)

def _reset_state(self) -> None:
"""Reset any per-connection state. Subclasses override as needed."""

@override
async def connect(self, address: str, timeout_s: float) -> None:
self._reset_state()
self._conn.port = address
logger.debug(f"Connecting to {self._conn.port=}")
start_time: Final = monotonic()
while monotonic() - start_time <= timeout_s:
try:
self._conn.open()
self._conn.reset_input_buffer()
logger.debug(f"Connected to {self._conn.port=}")
return
except SerialException as e:
logger.debug(
f"Failed to connect to {self._conn.port=}: {e}, "
f"retrying in {self._CONNECTION_RETRY_INTERVAL_S} seconds"
)
await asyncio.sleep(self._CONNECTION_RETRY_INTERVAL_S)
Comment thread
JPHutchins marked this conversation as resolved.
Comment thread
JPHutchins marked this conversation as resolved.

raise TimeoutError(f"Failed to connect to {address=}")

@final
@override
async def disconnect(self) -> None:
logger.debug(f"Disconnecting from {self._conn.port=}")
self._conn.close()
logger.debug(f"Disconnected from {self._conn.port=}")

@final
@override
async def send_and_receive(self, data: bytes) -> bytes:
await self.send(data)
return await self.receive()

@final
@contextmanager
def _serial_exception_to_disconnected(self) -> Generator[None, None, None]:
"""Translate `SerialException` from `pyserial` to `SMPTransportDisconnected`."""
try:
yield
except SerialException as e:
logger.error(f"Serial exception on {self._conn.port}: {e}")
raise SMPTransportDisconnected(
f"{self.__class__.__name__} disconnected from {self._conn.port}"
) from e

@final
async def _drain_tx(self) -> None:
"""Block until the serial TX buffer is empty.

Fake-async polling until `pyserial` is replaced.
"""
while self._conn.out_waiting > 0:
await asyncio.sleep(self._POLLING_INTERVAL_S)

@final
async def _read_all(self) -> bytes:
"""Return all currently-available bytes (or empty bytes).

Wraps `SerialException` into `SMPTransportDisconnected`. `StopIteration` is
caught to keep mocked `read_all` side-effect lists usable in tests.
"""
try:
return self._conn.read_all() or b""
except StopIteration:
return b""
except SerialException as exc:
raise SMPTransportDisconnected(f"Failed to read from {self._conn.port}: {exc}") from exc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""A serial `SMPTransport` for UART, USB CDC ACM, and CAN.
"""The base64-encoded serial `SMPTransport` for UART, USB CDC ACM, and CAN.

An SMP serial frame wraps the SMP message as `[uint16 length][message][uint16 CRC16]`,
base64-encodes it, and splits it into lines (<= 128 bytes by convention) on the wire.
Expand All @@ -8,6 +8,12 @@
many bytes once base64-encoded and line-framed, so the transport puts more than
`buf_size` encoded bytes on the wire -- which the server decodes incrementally.

This is what Zephyr calls "SMP over console" -- the framing shared by
`CONFIG_MCUMGR_TRANSPORT_UART` and `CONFIG_MCUMGR_TRANSPORT_SHELL`, and the only
SMP-over-UART option that existed before Zephyr 4.4. For
`CONFIG_MCUMGR_TRANSPORT_RAW_UART` servers, use `SMPSerialRawTransport` from
`smpclient.transport.serial.unencoded`.

The transport fills that decoded buffer for best throughput; how it learns the buffer
size is the `fragmentation_strategy` (`FragmentationStrategy`) -- see `Auto` (the
default), `BufferSize`, and `BufferParams`.
Expand All @@ -16,23 +22,14 @@
import asyncio
import logging
import math
import time
import warnings
from enum import IntEnum, unique
from typing import Final, NamedTuple, TypeAlias

try:
from serial import Serial, SerialException
except ModuleNotFoundError as e:
if e.name == "serial":
raise ImportError(
"Serial transport requires the 'serial' extra. Use smpclient[serial]"
) from e
raise
from smp import packet as smppacket
from typing_extensions import assert_never, deprecated, overload, override

from smpclient.transport import SMPTransport, SMPTransportDisconnected
from smpclient.transport.serial.common import _SerialTransportBase

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -177,10 +174,7 @@ class _LegacyParams(NamedTuple):
"""The internal strategy a constructor call resolves to (adds the deprecated `_LegacyParams`)."""


class SMPSerialTransport(SMPTransport):
_POLLING_INTERVAL_S = 0.005
_CONNECTION_RETRY_INTERVAL_S = 0.500

class SMPSerialTransport(_SerialTransportBase):
@unique
class BufferState(IntEnum):
SMP = 0
Expand Down Expand Up @@ -303,10 +297,7 @@ def __init__( # noqa: DOC301
exclusive: The exclusive access timeout.

"""
self._fragmentation_strategy: Final = self._resolve_fragmentation_strategy(
fragmentation_strategy, max_smp_encoded_frame_size, line_length, line_buffers
)
self._conn: Final = Serial(
super().__init__(
baudrate=baudrate,
bytesize=bytesize,
parity=parity,
Expand All @@ -320,6 +311,10 @@ def __init__( # noqa: DOC301
exclusive=exclusive,
)

self._fragmentation_strategy: Final = self._resolve_fragmentation_strategy(
fragmentation_strategy, max_smp_encoded_frame_size, line_length, line_buffers
)

self._smp_packet_queue: asyncio.Queue[bytes] = asyncio.Queue()
"""Contains full SMP packets."""
self._serial_buffer = bytearray()
Expand Down Expand Up @@ -441,6 +436,7 @@ def _validate_line_length(line_length: int) -> None:
f"cannot carry a base64 payload and would stall fragmentation"
)

@override
def _reset_state(self) -> None:
"""Reset internal state and queues for a fresh connection."""
self._smp_packet_queue = asyncio.Queue()
Expand Down Expand Up @@ -552,53 +548,19 @@ def initialize(self, smp_server_transport_buffer_size: int) -> None:
case _ as unreachable:
assert_never(unreachable)

@override
async def connect(self, address: str, timeout_s: float) -> None:
self._reset_state()
self._conn.port = address
logger.debug(f"Connecting to {self._conn.port=}")
start_time: Final = time.time()
while time.time() - start_time <= timeout_s:
try:
self._conn.open()
self._conn.reset_input_buffer()
logger.debug(f"Connected to {self._conn.port=}")
return
except SerialException as e:
logger.debug(
f"Failed to connect to {self._conn.port=}: {e}, "
f"retrying in {SMPSerialTransport._CONNECTION_RETRY_INTERVAL_S} seconds"
)
await asyncio.sleep(SMPSerialTransport._CONNECTION_RETRY_INTERVAL_S)

raise TimeoutError(f"Failed to connect to {address=}")

@override
async def disconnect(self) -> None:
logger.debug(f"Disconnecting from {self._conn.port=}")
self._conn.close()
logger.debug(f"Disconnected from {self._conn.port=}")

@override
async def send(self, data: bytes) -> None:
if len(data) > self.max_unencoded_size:
raise ValueError(
f"Data size {len(data)} exceeds maximum unencoded size {self.max_unencoded_size}"
)
logger.debug(f"Sending {len(data)} bytes")
try:
with self._serial_exception_to_disconnected():
for packet in smppacket.encode(data, line_length=self._line_length):
self._conn.write(packet)
logger.debug(f"Writing encoded packet of size {len(packet)}B; {self._line_length=}")

# fake async until I get around to replacing pyserial
while self._conn.out_waiting > 0:
await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S)
except SerialException as e:
logger.error(f"Failed to send {len(data)} bytes: {e}")
raise SMPTransportDisconnected(
f"{self.__class__.__name__} disconnected from {self._conn.port}"
)
await self._drain_tx()

logger.debug(f"Sent {len(data)} bytes")

Expand Down Expand Up @@ -652,18 +614,13 @@ async def read_serial(self, delimiter: bytes | None = None) -> bytes:
async def _read_and_process(self, read_until_one_smp_packet: bool) -> None:
"""Reads raw data from serial and processes it into SMP packets and regular serial data."""
while True:
try:
data = self._conn.read_all() or b""
except StopIteration:
data = b""
except SerialException as exc:
raise SMPTransportDisconnected(f"Failed to read from {self._conn.port}: {exc}")
data = await self._read_all()

if data:
self._buffer.extend(data)
await self._process_buffer()
else:
await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S)
await asyncio.sleep(self._POLLING_INTERVAL_S)

if read_until_one_smp_packet:
if self._smp_packet_queue.qsize():
Expand Down Expand Up @@ -752,11 +709,6 @@ def _could_be_smp_packet_start(self, byte: int) -> bool:
"""Return True if the given byte value matches the start of any SMP packet delimiter."""
return byte == smppacket.START_DELIMITER[0] or byte == smppacket.CONTINUE_DELIMITER[0]

@override
async def send_and_receive(self, data: bytes) -> bytes:
await self.send(data)
return await self.receive()

@override
@property
def mtu(self) -> int:
Expand Down
Loading
Loading