diff --git a/src/smpclient/transport/serial/__init__.py b/src/smpclient/transport/serial/__init__.py new file mode 100644 index 0000000..50d5009 --- /dev/null +++ b/src/smpclient/transport/serial/__init__.py @@ -0,0 +1,11 @@ +"""Serial SMPTransports. + +In addition to UART, these transports can be used with USB CDC ACM and CAN. +""" + +from smpclient.transport.serial.encoded import Auto as Auto +from smpclient.transport.serial.encoded import BufferParams as BufferParams +from smpclient.transport.serial.encoded import BufferSize as BufferSize +from smpclient.transport.serial.encoded import FragmentationStrategy as FragmentationStrategy +from smpclient.transport.serial.encoded import SMPSerialTransport as SMPSerialTransport +from smpclient.transport.serial.unencoded import SMPSerialRawTransport as SMPSerialRawTransport diff --git a/src/smpclient/transport/serial/common.py b/src/smpclient/transport/serial/common.py new file mode 100644 index 0000000..aa0866e --- /dev/null +++ b/src/smpclient/transport/serial/common.py @@ -0,0 +1,156 @@ +"""Shared connection management for the encoded and unencoded serial transports.""" + +import asyncio +import logging +from contextlib import contextmanager +from time import monotonic +from typing import Final, Generator, final + +try: + from serial import Serial, SerialException +except ModuleNotFoundError as e: + if e.name == "serial": + raise ImportError( + "Serial transport requires the 'serial' extra. Use smpclient[serial]" + ) from e + raise +from typing_extensions import override + +from smpclient.transport import SMPTransport, SMPTransportDisconnected + +logger = logging.getLogger(__name__) + + +class _SerialTransportBase(SMPTransport): + """Connection-management base class for serial-port-backed SMP transports. + + Holds the `pyserial` `Serial` instance, the open/retry connect loop, disconnect, + and the small TX/RX helpers that wrap `SerialException` into + `SMPTransportDisconnected`. + + Subclasses implement `send` and `receive` with their framing of choice, may + override `_reset_state` to clear per-connection state on `connect`, and may + override `connect` to back the transport with a byte pipe other than a local + serial port (e.g. an emulator's `socket://` chardev). + """ + + _POLLING_INTERVAL_S: Final = 0.005 + _CONNECTION_RETRY_INTERVAL_S: Final = 0.500 + + def __init__( + self, + baudrate: int = 115200, + bytesize: int = 8, + parity: str = "N", + stopbits: float = 1, + timeout: float | None = None, + xonxoff: bool = False, + rtscts: bool = False, + write_timeout: float | None = None, + dsrdtr: bool = False, + inter_byte_timeout: float | None = None, + exclusive: bool | None = None, + ) -> None: + """Initialize the underlying `pyserial` `Serial` instance. + + Args: + baudrate: The baudrate of the serial connection. OK to ignore for + USB CDC ACM. + bytesize: The number of data bits. + parity: The parity setting. + stopbits: The number of stop bits. + timeout: The read timeout. + xonxoff: Enable software flow control. + rtscts: Enable hardware (RTS/CTS) flow control. + write_timeout: The write timeout. + dsrdtr: Enable hardware (DSR/DTR) flow control. + inter_byte_timeout: The inter-byte timeout. + exclusive: Set exclusive access mode (POSIX only). A port cannot be + opened in exclusive access mode if it is already open in + exclusive access mode. + """ + self._conn: Final = Serial( + baudrate=baudrate, + bytesize=bytesize, + parity=parity, + stopbits=stopbits, + timeout=timeout, + xonxoff=xonxoff, + rtscts=rtscts, + write_timeout=write_timeout, + dsrdtr=dsrdtr, + inter_byte_timeout=inter_byte_timeout, + exclusive=exclusive, + ) + + def _reset_state(self) -> None: + """Reset any per-connection state. Subclasses override as needed.""" + + @override + async def connect(self, address: str, timeout_s: float) -> None: + self._reset_state() + self._conn.port = address + logger.debug(f"Connecting to {self._conn.port=}") + start_time: Final = monotonic() + while monotonic() - start_time <= timeout_s: + try: + self._conn.open() + self._conn.reset_input_buffer() + logger.debug(f"Connected to {self._conn.port=}") + return + except SerialException as e: + logger.debug( + f"Failed to connect to {self._conn.port=}: {e}, " + f"retrying in {self._CONNECTION_RETRY_INTERVAL_S} seconds" + ) + await asyncio.sleep(self._CONNECTION_RETRY_INTERVAL_S) + + raise TimeoutError(f"Failed to connect to {address=}") + + @final + @override + async def disconnect(self) -> None: + logger.debug(f"Disconnecting from {self._conn.port=}") + self._conn.close() + logger.debug(f"Disconnected from {self._conn.port=}") + + @final + @override + async def send_and_receive(self, data: bytes) -> bytes: + await self.send(data) + return await self.receive() + + @final + @contextmanager + def _serial_exception_to_disconnected(self) -> Generator[None, None, None]: + """Translate `SerialException` from `pyserial` to `SMPTransportDisconnected`.""" + try: + yield + except SerialException as e: + logger.error(f"Serial exception on {self._conn.port}: {e}") + raise SMPTransportDisconnected( + f"{self.__class__.__name__} disconnected from {self._conn.port}" + ) from e + + @final + async def _drain_tx(self) -> None: + """Block until the serial TX buffer is empty. + + Fake-async polling until `pyserial` is replaced. + """ + while self._conn.out_waiting > 0: + await asyncio.sleep(self._POLLING_INTERVAL_S) + + @final + async def _read_all(self) -> bytes: + """Return all currently-available bytes (or empty bytes). + + Wraps `SerialException` into `SMPTransportDisconnected`. `StopIteration` is + caught to keep mocked `read_all` side-effect lists usable in tests. + """ + try: + return self._conn.read_all() or b"" + except StopIteration: + return b"" + except SerialException as exc: + raise SMPTransportDisconnected(f"Failed to read from {self._conn.port}: {exc}") from exc diff --git a/src/smpclient/transport/serial.py b/src/smpclient/transport/serial/encoded.py similarity index 92% rename from src/smpclient/transport/serial.py rename to src/smpclient/transport/serial/encoded.py index 8939aa7..d2e322b 100644 --- a/src/smpclient/transport/serial.py +++ b/src/smpclient/transport/serial/encoded.py @@ -1,4 +1,4 @@ -"""A serial `SMPTransport` for UART, USB CDC ACM, and CAN. +"""The base64-encoded serial `SMPTransport` for UART, USB CDC ACM, and CAN. An SMP serial frame wraps the SMP message as `[uint16 length][message][uint16 CRC16]`, base64-encodes it, and splits it into lines (<= 128 bytes by convention) on the wire. @@ -8,6 +8,12 @@ many bytes once base64-encoded and line-framed, so the transport puts more than `buf_size` encoded bytes on the wire -- which the server decodes incrementally. +This is what Zephyr calls "SMP over console" -- the framing shared by +`CONFIG_MCUMGR_TRANSPORT_UART` and `CONFIG_MCUMGR_TRANSPORT_SHELL`, and the only +SMP-over-UART option that existed before Zephyr 4.4. For +`CONFIG_MCUMGR_TRANSPORT_RAW_UART` servers, use `SMPSerialRawTransport` from +`smpclient.transport.serial.unencoded`. + The transport fills that decoded buffer for best throughput; how it learns the buffer size is the `fragmentation_strategy` (`FragmentationStrategy`) -- see `Auto` (the default), `BufferSize`, and `BufferParams`. @@ -16,23 +22,14 @@ import asyncio import logging import math -import time import warnings from enum import IntEnum, unique from typing import Final, NamedTuple, TypeAlias -try: - from serial import Serial, SerialException -except ModuleNotFoundError as e: - if e.name == "serial": - raise ImportError( - "Serial transport requires the 'serial' extra. Use smpclient[serial]" - ) from e - raise from smp import packet as smppacket from typing_extensions import assert_never, deprecated, overload, override -from smpclient.transport import SMPTransport, SMPTransportDisconnected +from smpclient.transport.serial.common import _SerialTransportBase logger = logging.getLogger(__name__) @@ -177,10 +174,7 @@ class _LegacyParams(NamedTuple): """The internal strategy a constructor call resolves to (adds the deprecated `_LegacyParams`).""" -class SMPSerialTransport(SMPTransport): - _POLLING_INTERVAL_S = 0.005 - _CONNECTION_RETRY_INTERVAL_S = 0.500 - +class SMPSerialTransport(_SerialTransportBase): @unique class BufferState(IntEnum): SMP = 0 @@ -303,10 +297,7 @@ def __init__( # noqa: DOC301 exclusive: The exclusive access timeout. """ - self._fragmentation_strategy: Final = self._resolve_fragmentation_strategy( - fragmentation_strategy, max_smp_encoded_frame_size, line_length, line_buffers - ) - self._conn: Final = Serial( + super().__init__( baudrate=baudrate, bytesize=bytesize, parity=parity, @@ -320,6 +311,10 @@ def __init__( # noqa: DOC301 exclusive=exclusive, ) + self._fragmentation_strategy: Final = self._resolve_fragmentation_strategy( + fragmentation_strategy, max_smp_encoded_frame_size, line_length, line_buffers + ) + self._smp_packet_queue: asyncio.Queue[bytes] = asyncio.Queue() """Contains full SMP packets.""" self._serial_buffer = bytearray() @@ -441,6 +436,7 @@ def _validate_line_length(line_length: int) -> None: f"cannot carry a base64 payload and would stall fragmentation" ) + @override def _reset_state(self) -> None: """Reset internal state and queues for a fresh connection.""" self._smp_packet_queue = asyncio.Queue() @@ -552,33 +548,6 @@ def initialize(self, smp_server_transport_buffer_size: int) -> None: case _ as unreachable: assert_never(unreachable) - @override - async def connect(self, address: str, timeout_s: float) -> None: - self._reset_state() - self._conn.port = address - logger.debug(f"Connecting to {self._conn.port=}") - start_time: Final = time.time() - while time.time() - start_time <= timeout_s: - try: - self._conn.open() - self._conn.reset_input_buffer() - logger.debug(f"Connected to {self._conn.port=}") - return - except SerialException as e: - logger.debug( - f"Failed to connect to {self._conn.port=}: {e}, " - f"retrying in {SMPSerialTransport._CONNECTION_RETRY_INTERVAL_S} seconds" - ) - await asyncio.sleep(SMPSerialTransport._CONNECTION_RETRY_INTERVAL_S) - - raise TimeoutError(f"Failed to connect to {address=}") - - @override - async def disconnect(self) -> None: - logger.debug(f"Disconnecting from {self._conn.port=}") - self._conn.close() - logger.debug(f"Disconnected from {self._conn.port=}") - @override async def send(self, data: bytes) -> None: if len(data) > self.max_unencoded_size: @@ -586,19 +555,12 @@ async def send(self, data: bytes) -> None: f"Data size {len(data)} exceeds maximum unencoded size {self.max_unencoded_size}" ) logger.debug(f"Sending {len(data)} bytes") - try: + with self._serial_exception_to_disconnected(): for packet in smppacket.encode(data, line_length=self._line_length): self._conn.write(packet) logger.debug(f"Writing encoded packet of size {len(packet)}B; {self._line_length=}") - # fake async until I get around to replacing pyserial - while self._conn.out_waiting > 0: - await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S) - except SerialException as e: - logger.error(f"Failed to send {len(data)} bytes: {e}") - raise SMPTransportDisconnected( - f"{self.__class__.__name__} disconnected from {self._conn.port}" - ) + await self._drain_tx() logger.debug(f"Sent {len(data)} bytes") @@ -652,18 +614,13 @@ async def read_serial(self, delimiter: bytes | None = None) -> bytes: async def _read_and_process(self, read_until_one_smp_packet: bool) -> None: """Reads raw data from serial and processes it into SMP packets and regular serial data.""" while True: - try: - data = self._conn.read_all() or b"" - except StopIteration: - data = b"" - except SerialException as exc: - raise SMPTransportDisconnected(f"Failed to read from {self._conn.port}: {exc}") + data = await self._read_all() if data: self._buffer.extend(data) await self._process_buffer() else: - await asyncio.sleep(SMPSerialTransport._POLLING_INTERVAL_S) + await asyncio.sleep(self._POLLING_INTERVAL_S) if read_until_one_smp_packet: if self._smp_packet_queue.qsize(): @@ -752,11 +709,6 @@ def _could_be_smp_packet_start(self, byte: int) -> bool: """Return True if the given byte value matches the start of any SMP packet delimiter.""" return byte == smppacket.START_DELIMITER[0] or byte == smppacket.CONTINUE_DELIMITER[0] - @override - async def send_and_receive(self, data: bytes) -> bytes: - await self.send(data) - return await self.receive() - @override @property def mtu(self) -> int: diff --git a/src/smpclient/transport/serial/unencoded.py b/src/smpclient/transport/serial/unencoded.py new file mode 100644 index 0000000..05dcf0c --- /dev/null +++ b/src/smpclient/transport/serial/unencoded.py @@ -0,0 +1,138 @@ +"""The unencoded (raw) serial SMPTransport. + +This is the Zephyr "raw UART" SMP transport, enabled on the server by +`CONFIG_MCUMGR_TRANSPORT_RAW_UART` together with `CONFIG_UART_MCUMGR_RAW_PROTOCOL`. +Each SMP message is sent over the wire as the raw bytes +`[8-byte SMP header][header.length bytes of payload]` with no framing, encoding, +or CRC. The receiver parses the SMP header to determine the message length. + +This transport cannot coexist with shell or log output on the same UART. If +you need shell interleaving, use `SMPSerialTransport` from +`smpclient.transport.serial.encoded`. +""" + +import asyncio +import logging +from typing import Final + +from smp import header as smphdr +from typing_extensions import override + +from smpclient.exceptions import SMPClientException +from smpclient.transport.serial.common import _SerialTransportBase + +logger = logging.getLogger(__name__) + + +class SMPSerialRawTransport(_SerialTransportBase): + def __init__( + self, + mtu: int = 384, + baudrate: int = 115200, + bytesize: int = 8, + parity: str = "N", + stopbits: float = 1, + timeout: float | None = None, + xonxoff: bool = False, + rtscts: bool = False, + write_timeout: float | None = None, + dsrdtr: bool = False, + inter_byte_timeout: float | None = None, + exclusive: bool | None = None, + ) -> None: + """Initialize the raw serial transport. + + Args: + mtu: The maximum size of one SMP message (header + payload), in + bytes. A serial link has no MTU of its own, but the SMP + server's receive buffer does -- this should match the server's + `CONFIG_MCUMGR_TRANSPORT_NETBUF_SIZE` (Zephyr default 384). + baudrate: The baudrate of the serial connection. OK to ignore for + USB CDC ACM. + bytesize: The number of data bits. + parity: The parity setting. + stopbits: The number of stop bits. + timeout: The read timeout. + xonxoff: Enable software flow control. + rtscts: Enable hardware (RTS/CTS) flow control. + write_timeout: The write timeout. + dsrdtr: Enable hardware (DSR/DTR) flow control. + inter_byte_timeout: The inter-byte timeout. + exclusive: Set exclusive access mode (POSIX only). A port cannot be + opened in exclusive access mode if it is already open in + exclusive access mode. + """ + super().__init__( + baudrate=baudrate, + bytesize=bytesize, + parity=parity, + stopbits=stopbits, + timeout=timeout, + xonxoff=xonxoff, + rtscts=rtscts, + write_timeout=write_timeout, + dsrdtr=dsrdtr, + inter_byte_timeout=inter_byte_timeout, + exclusive=exclusive, + ) + self._mtu: Final = mtu + + logger.debug(f"Initialized {self.__class__.__name__}") + + @override + async def send(self, data: bytes) -> None: + if len(data) > self.max_unencoded_size: + raise ValueError( + f"Data size {len(data)} exceeds maximum unencoded size {self.max_unencoded_size}" + ) + logger.debug(f"Sending {len(data)} bytes") + with self._serial_exception_to_disconnected(): + self._conn.write(data) + await self._drain_tx() + logger.debug(f"Sent {len(data)} bytes") + + @override + async def receive(self) -> bytes: + logger.debug("Waiting for response") + message = bytearray() + + while len(message) < smphdr.Header.SIZE: + await self._poll_read_into(message) + + header: Final = smphdr.Header.loads(bytes(message[: smphdr.Header.SIZE])) + message_length: Final = header.length + smphdr.Header.SIZE + logger.debug(f"Received {header=}; awaiting {message_length} B total") + + # The header's length field is attacker/noise-controlled - bound it before + # we start waiting for that many bytes to arrive. + if message_length > self.max_unencoded_size: + error = ( + f"Header claims a {message_length} B message, " + f"exceeding max_unencoded_size={self.max_unencoded_size}" + ) + logger.error(error) + raise SMPClientException(error) + + while len(message) < message_length: + await self._poll_read_into(message) + + if len(message) > message_length: + error = f"Received more data than expected: {len(message)} B > {message_length} B" + logger.error(error) + raise SMPClientException(error) + + logger.debug(f"Finished receiving {message_length} B response") + return bytes(message) + + async def _poll_read_into(self, buf: bytearray) -> None: + """Read available bytes into `buf`; if none, yield via a short sleep.""" + data = await self._read_all() + if data: + buf.extend(data) + else: + await asyncio.sleep(self._POLLING_INTERVAL_S) + + @override + @property + def mtu(self) -> int: + return self._mtu diff --git a/tests/fixtures/smp-server/zephyr_4.4.0_smp_server_0eae053d_native_sim_serial_raw.exe b/tests/fixtures/smp-server/zephyr_4.4.0_smp_server_0eae053d_native_sim_serial_raw.exe new file mode 100755 index 0000000..af04593 Binary files /dev/null and b/tests/fixtures/smp-server/zephyr_4.4.0_smp_server_0eae053d_native_sim_serial_raw.exe differ diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d8bc5b9..7db676e 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -25,7 +25,7 @@ from smpclient.generics import success from smpclient.requests.os_management import EchoWrite from smpclient.transport import SMPTransport -from smpclient.transport.serial import SMPSerialTransport +from smpclient.transport.serial import SMPSerialRawTransport, SMPSerialTransport from smpclient.transport.udp import SMPUDPTransport from tests.integration.servers import ( FIXTURES, @@ -72,10 +72,18 @@ def fixture_params( """`SMPUDPTransport.connect`'s default port; `SMPClient.connect` cannot override it.""" -def _build_transport(endpoint: Endpoint) -> tuple[SMPTransport, str]: +def _build_transport(fixture: ServerFixture, endpoint: Endpoint) -> tuple[SMPTransport, str]: match endpoint: case PtyEndpoint(pty): - return SMPSerialTransport(), pty + match fixture.transport: + case "serial" | "shell": + return SMPSerialTransport(), pty + case "serial_raw": + return SMPSerialRawTransport(), pty + case "udp": + pytest.fail("UDP fixtures do not present as a PTY serial endpoint") + case _ as unreachable: + assert_never(unreachable) case SocketSerialEndpoint(url): return QemuSocketSerialTransport(url), url case UdpEndpoint(host, port): @@ -159,7 +167,7 @@ def assert_chunks_maximized( async def connected(fixture: ServerFixture) -> AsyncIterator[ConnectedServer]: """Launch `fixture`, connect an `SMPClient`, and wait until the server answers.""" async with serve(fixture) as endpoint: - transport, address = _build_transport(endpoint) + transport, address = _build_transport(fixture, endpoint) client = SMPClient(transport, address) await client.connect() await _wait_until_answering(client) diff --git a/tests/integration/servers.py b/tests/integration/servers.py index 475bdf1..45455a4 100644 --- a/tests/integration/servers.py +++ b/tests/integration/servers.py @@ -111,11 +111,6 @@ def path(self) -> Path: def emulated(self) -> bool: return self.qemu_cmd is not None - @property - def client_supported(self) -> bool: - """`False` when smpclient has no transport that speaks this server's protocol.""" - return self.transport != "serial_raw" - @property def params_supported(self) -> bool: """`False` for builds with the MCUmgr params command disabled (`noparams`).""" @@ -156,8 +151,6 @@ def has_group(self, group: str) -> bool: def unavailable_reason(self) -> str | None: """Return why this fixture cannot run on this host, or `None` if it can.""" - if not self.client_supported: - return "smpclient has no raw-UART (CONFIG_MCUMGR_TRANSPORT_RAW_UART) transport yet" if platform.system() != "Linux": return "SMP server fixtures run on Linux only" if not self.path.is_file(): diff --git a/tests/integration/test_serial_raw.py b/tests/integration/test_serial_raw.py new file mode 100644 index 0000000..4148e11 --- /dev/null +++ b/tests/integration/test_serial_raw.py @@ -0,0 +1,32 @@ +"""Raw-UART (`CONFIG_MCUMGR_TRANSPORT_RAW_UART`) serial integration tests. + +`SMPSerialRawTransport` round-trips (echo, enumeration, MCUmgr params) are exercised by +the generic `connected_server` suite; these cover what is specific to the raw transport. +""" + +from __future__ import annotations + +import pytest + +from smpclient.transport.serial import SMPSerialRawTransport +from tests.integration.conftest import ConnectedServer + +pytestmark = [pytest.mark.integration, pytest.mark.asyncio] + + +async def test_raw_transport_sizes_messages_to_server_buf_size( + connected_server: ConnectedServer, +) -> None: + """The raw transport caps an SMP message at the server's reported `buf_size`. + + `_initialize` reads the server's MCUmgr `buf_size` and the raw transport adopts it + verbatim as `max_unencoded_size` -- the chunk size `SMPClient.upload` fills -- with + no on-wire framing to subtract (the whole `[header][payload]` rides in the netbuf). + """ + cs = connected_server + if cs.fixture.transport != "serial_raw": + pytest.skip("raw UART transport") + + transport = cs.client._transport + assert isinstance(transport, SMPSerialRawTransport) + assert transport.max_unencoded_size == cs.fixture.buf_size diff --git a/tests/test_base64.py b/tests/test_base64.py index e1f600f..6f26ca9 100644 --- a/tests/test_base64.py +++ b/tests/test_base64.py @@ -3,7 +3,7 @@ import random from base64 import b64encode -from smpclient.transport.serial import _base64_cost, _base64_max +from smpclient.transport.serial.encoded import _base64_cost, _base64_max if not hasattr(random, 'randbytes'): from os import urandom diff --git a/tests/test_smp_client.py b/tests/test_smp_client.py index 8a6d307..22666e8 100644 --- a/tests/test_smp_client.py +++ b/tests/test_smp_client.py @@ -36,7 +36,12 @@ from smpclient.requests.file_management import FileDownload, FileUpload from smpclient.requests.image_management import ImageUploadWrite from smpclient.requests.os_management import ResetWrite -from smpclient.transport.serial import BufferParams, BufferSize, SMPSerialTransport +from smpclient.transport.serial import ( + BufferParams, + BufferSize, + SMPSerialRawTransport, + SMPSerialTransport, +) FRAME_OVERHEAD = smppacket.FRAME_LENGTH_STRUCT.size + smppacket.CRC16_STRUCT.size """The SMP serial frame's 2-byte length + 2-byte CRC16 that share the decoded buffer.""" @@ -399,6 +404,53 @@ async def mock_request( assert reconstructed_image == image +@pytest.mark.asyncio +@pytest.mark.parametrize("mtu", [128, 256, 512, 1024, 2048, 4096, 8192]) +async def test_upload_hello_world_bin_raw(mtu: int) -> None: + with open( + str(Path("tests", "fixtures", "zephyr-v3.5.0-2795-g28ff83515d", "hello_world.signed.bin")), + 'rb', + ) as f: + image = f.read() + + m = SMPSerialRawTransport(mtu=mtu) + s = SMPClient(m, "address") + assert s._transport.mtu == mtu + assert s._transport.max_unencoded_size == mtu, "The raw transport has no encoding overhead" + + packets: list[bytes] = [] + + def mock_write(data: bytes) -> int: + """Accumulate the raw packets in the global `packets`.""" + packets.append(data) + return len(data) + + s._transport._conn.write = mock_write # type: ignore + + async def mock_request( + request: ImageUploadWrite, timeout_s: float = 120.000 + ) -> ImageUploadWriteResponse: + # call the real send method (with write mocked) but don't bother with receive + # this provides coverage for the MTU-limited chunking done by SMPClient.upload + await s._transport.send(request.BYTES) + return ImageUploadWrite._Response.get_default()(off=request.off + len(request.data)) # type: ignore # noqa + + s.request = mock_request # type: ignore + + # `out_waiting` is a property on the real Serial class - scope the patch so it + # restores cleanly when the test finishes. + with patch.object(type(s._transport._conn), 'out_waiting', 0): # type: ignore + async for _ in s.upload(image): + pass + + # Each captured write is one complete SMP message [header][payload], no decoding needed. + reconstructed_image = bytearray([]) + for packet in packets: + reconstructed_image.extend(ImageUploadWriteRequest.loads(packet).data) + + assert reconstructed_image == image + + @pytest.mark.asyncio async def test_upload_file() -> None: m = SMPMockTransport() diff --git a/tests/test_smp_serial_raw_transport.py b/tests/test_smp_serial_raw_transport.py new file mode 100644 index 0000000..53ef77e --- /dev/null +++ b/tests/test_smp_serial_raw_transport.py @@ -0,0 +1,254 @@ +"""Tests for `SMPSerialRawTransport`.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Generator +from typing import Any +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch + +import pytest +from serial import SerialException +from smp import header as smphdr + +from smpclient.exceptions import SMPClientException +from smpclient.requests.os_management import EchoWrite +from smpclient.transport import SMPTransportDisconnected +from smpclient.transport.serial import SMPSerialRawTransport + + +@pytest.fixture(autouse=True) +def mock_serial() -> Generator[None, Any, None]: + with patch("smpclient.transport.serial.common.Serial"): + yield + + +def test_constructor() -> None: + t = SMPSerialRawTransport(mtu=512) + assert t.mtu == 512 + assert t.max_unencoded_size == 512 + + +def test_constructor_defaults() -> None: + t = SMPSerialRawTransport() + assert t.mtu == 384 + + +@pytest.mark.asyncio +async def test_connect_disconnect() -> None: + ports: list[str] = ["COM2", "/dev/ttyACM0", "/dev/ttyUSB0"] + + t = SMPSerialRawTransport() + t._conn.read_all = MagicMock(return_value=b"") # type: ignore + + for p in ports: + await asyncio.wait_for(t.connect(p, 1.0), timeout=1.0) + t._conn.open.assert_called_once() # type: ignore + + assert t._conn.port == p + + await asyncio.wait_for(t.disconnect(), timeout=0.1) + t._conn.close.assert_called_once() # type: ignore + + t._conn.reset_mock() # type: ignore + + +@pytest.mark.asyncio +async def test_connect_retries_until_timeout() -> None: + t = SMPSerialRawTransport() + t._conn.open = MagicMock(side_effect=SerialException("nope")) # type: ignore + + with pytest.raises(TimeoutError): + await asyncio.wait_for(t.connect("/dev/ttyUSB0", 0.1), timeout=2.0) + + +@pytest.mark.asyncio +async def test_send() -> None: + t = SMPSerialRawTransport() + t._conn.write = MagicMock() # type: ignore + p = PropertyMock(return_value=0) + type(t._conn).out_waiting = p # type: ignore + + r = EchoWrite(d="Hello pytest!") + await t.send(r.BYTES) + + # Raw transport writes the bytes verbatim - no encoding. + t._conn.write.assert_called_once_with(r.BYTES) + p.assert_called_once_with() + + +@pytest.mark.asyncio +async def test_send_waits_for_tx_drain() -> None: + t = SMPSerialRawTransport() + t._conn.write = MagicMock() # type: ignore + p = PropertyMock(side_effect=(1, 0)) + type(t._conn).out_waiting = p # type: ignore + + await t.send(EchoWrite(d="x").BYTES) + assert p.call_count == 2 + + +@pytest.mark.asyncio +async def test_send_too_large_raises() -> None: + t = SMPSerialRawTransport(mtu=16) + with pytest.raises(ValueError): + await t.send(b"\x00" * 32) + + +@pytest.mark.asyncio +async def test_send_disconnected_raises() -> None: + t = SMPSerialRawTransport() + t._conn.write = MagicMock(side_effect=SerialException("disconnected")) # type: ignore + + with pytest.raises(SMPTransportDisconnected): + await t.send(EchoWrite(d="x").BYTES) + + +@pytest.mark.asyncio +async def test_receive_single_packet() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="Hello pytest!") # type: ignore + t._conn.read_all = MagicMock(side_effect=[m.BYTES]) # type: ignore + + received = await t.receive() + assert received == m.BYTES + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_fragmented() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="Hello pytest!") # type: ignore + fragments = [ + m.BYTES[:3], # less than a header + m.BYTES[3:8], # completes the header but no payload yet + m.BYTES[8:10], + m.BYTES[10:], # rest of payload + ] + t._conn.read_all = MagicMock(side_effect=fragments) # type: ignore + + received = await t.receive() + assert received == m.BYTES + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_byte_at_a_time() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="Hi") # type: ignore + t._conn.read_all = MagicMock( # type: ignore + side_effect=[bytes([b]) for b in m.BYTES] + ) + + received = await t.receive() + assert received == m.BYTES + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_consecutive_messages() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m1 = EchoWrite._Response.get_default()(sequence=0, r="SMP Message 1") # type: ignore + m2 = EchoWrite._Response.get_default()(sequence=1, r="SMP Message 2") # type: ignore + m3 = EchoWrite._Response.get_default()(sequence=2, r="SMP Message 3") # type: ignore + + # Each receive() reads one full message, just like a normal request/response loop. + t._conn.read_all = MagicMock(side_effect=[m1.BYTES, m2.BYTES, m3.BYTES]) # type: ignore + + assert await t.receive() == m1.BYTES + assert await t.receive() == m2.BYTES + assert await t.receive() == m3.BYTES + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_overrun_raises() -> None: + """A single read returning more bytes than the header advertises is an error. + + SMP is strictly request/response; the server should never send unsolicited bytes. + """ + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="Hello!") # type: ignore + t._conn.read_all = MagicMock(side_effect=[m.BYTES + b"\x00\x01\x02"]) # type: ignore + + with pytest.raises(SMPClientException): + await t.receive() + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_polls_when_nothing_available() -> None: + t = SMPSerialRawTransport() + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + m = EchoWrite._Response.get_default()(sequence=0, r="ok") # type: ignore + t._conn.read_all = MagicMock(side_effect=[b"", b"", m.BYTES]) # type: ignore + + received = await t.receive() + assert received == m.BYTES + assert t._conn.read_all.call_count >= 3 + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_oversized_header_raises() -> None: + """A header claiming more bytes than max_unencoded_size is rejected. + + Defensive bound against noisy or corrupted UART traffic that would + otherwise cause an unbounded wait. + """ + t = SMPSerialRawTransport(mtu=64) + await t.connect("/dev/ttyUSB0", timeout_s=1.0) + + bogus_header = smphdr.Header( + op=smphdr.OP.WRITE_RSP, + version=smphdr.Version.V2, + flags=smphdr.Flag(0), + length=10_000, + group_id=smphdr.GroupId.OS_MANAGEMENT, + sequence=0, + command_id=smphdr.CommandId.OSManagement.ECHO, + ).BYTES + t._conn.read_all = MagicMock(side_effect=[bogus_header]) # type: ignore + + with pytest.raises(SMPClientException): + await t.receive() + + await t.disconnect() + + +@pytest.mark.asyncio +async def test_receive_disconnected_raises() -> None: + t = SMPSerialRawTransport() + t._conn.read_all = MagicMock(side_effect=SerialException("disconnected")) # type: ignore + + with pytest.raises(SMPTransportDisconnected): + await t.receive() + + +@pytest.mark.asyncio +async def test_send_and_receive() -> None: + t = SMPSerialRawTransport() + t.send = AsyncMock() # type: ignore + t.receive = AsyncMock() # type: ignore + + await t.send_and_receive(b"some data") + + t.send.assert_awaited_once_with(b"some data") + t.receive.assert_awaited_once_with() diff --git a/tests/test_smp_serial_transport.py b/tests/test_smp_serial_transport.py index ab39d77..077aa2e 100644 --- a/tests/test_smp_serial_transport.py +++ b/tests/test_smp_serial_transport.py @@ -29,7 +29,7 @@ @pytest.fixture(autouse=True) def mock_serial() -> Generator[None, Any, None]: - with patch("smpclient.transport.serial.Serial"): + with patch("smpclient.transport.serial.common.Serial"): yield