From bfbe2a2e007eb0c7fc6ce5bd02ecc0e9793105e9 Mon Sep 17 00:00:00 2001 From: deeleeramone <> Date: Mon, 15 Jun 2026 17:28:30 -0700 Subject: [PATCH 1/2] merge branch develop --- pywry/pywry/chat/providers/deepagent.py | 60 +- pywry/pywry/cli.py | 3 - pywry/pywry/config.py | 2 +- pywry/pywry/inline.py | 199 +- pywry/pywry/toolbar.py | 15 +- pywry/pywry/tvchart/normalize.py | 2 - pywry/pywry/tvchart/toolbars.py | 3 - pywry/tests/conftest.py | 77 + pywry/tests/test_alerts.py | 7 - pywry/tests/test_asset_loader.py | 171 ++ pywry/tests/test_assets.py | 317 +++ pywry/tests/test_auth_callback_server.py | 72 + pywry/tests/test_auth_deploy_routes.py | 569 +++- pywry/tests/test_auth_flow_integration.py | 254 +- pywry/tests/test_auth_providers.py | 705 ++++- pywry/tests/test_auth_session.py | 174 ++ pywry/tests/test_auth_token_store.py | 275 +- pywry/tests/test_browser_mode_e2e.py | 6 +- pywry/tests/test_chat.py | 318 +-- pywry/tests/test_chat_manager.py | 1680 ++++++++++-- pywry/tests/test_cli.py | 424 +++ pywry/tests/test_config.py | 347 +++ pywry/tests/test_deepagent_provider.py | 1733 ++++++++++-- pywry/tests/test_e2e_deploy_mode.py | 5 +- pywry/tests/test_e2e_rbac_widgets.py | 5 +- pywry/tests/test_error_paths.py | 583 ---- pywry/tests/test_exceptions.py | 18 + pywry/tests/test_freeze.py | 38 + pywry/tests/test_grid.py | 613 +++++ pywry/tests/test_hot_reload.py | 680 ++--- pywry/tests/test_inline_e2e.py | 5 - pywry/tests/test_marquee_e2e.py | 13 +- pywry/tests/test_mcp_app_artifact.py | 301 +- pywry/tests/test_mcp_state_helpers.py | 232 -- pywry/tests/test_mcp_unit.py | 1770 ------------ pywry/tests/test_menu_tray.py | 175 ++ pywry/tests/test_modal.py | 40 + pywry/tests/test_modal_e2e.py | 7 +- pywry/tests/test_models.py | 228 ++ pywry/tests/test_notebook_detection.py | 386 --- pywry/tests/test_scripts.py | 9 + pywry/tests/test_state_memory.py | 556 +++- pywry/tests/test_state_mixins.py | 367 +++ pywry/tests/test_state_redis.py | 839 +++++- pywry/tests/test_state_sqlite.py | 807 ++++-- pywry/tests/test_templates.py | 545 ++++ pywry/tests/test_toolbar.py | 769 ++++++ pywry/tests/test_tvchart.py | 3050 --------------------- pywry/tests/test_tvchart_e2e.py | 10 +- pywry/tests/test_types.py | 210 ++ pywry/tests/test_udf_adapter.py | 562 ---- pywry/tests/test_watcher.py | 587 ++-- pywry/tests/test_widget_protocol.py | 157 ++ pywry/tests/test_window_dispatch.py | 1062 ++++++- pywry/tests/test_window_proxy.py | 617 ++++- 55 files changed, 14077 insertions(+), 8582 deletions(-) delete mode 100644 pywry/tests/test_error_paths.py delete mode 100644 pywry/tests/test_mcp_state_helpers.py delete mode 100644 pywry/tests/test_mcp_unit.py delete mode 100644 pywry/tests/test_notebook_detection.py delete mode 100644 pywry/tests/test_tvchart.py delete mode 100644 pywry/tests/test_udf_adapter.py diff --git a/pywry/pywry/chat/providers/deepagent.py b/pywry/pywry/chat/providers/deepagent.py index d65f991..93fd42e 100644 --- a/pywry/pywry/chat/providers/deepagent.py +++ b/pywry/pywry/chat/providers/deepagent.py @@ -161,65 +161,53 @@ def _step_in_call(self, ch: str) -> None: self._in_string = False self._escape = False - def _step_in_special(self, ch: str, out: list[str]) -> None: - """Advance the ``<|...|>`` state machine; recurse on tail after ``|>``.""" + def _step_in_special(self, ch: str, _out: list[str]) -> None: + """Advance the ``<|...|>`` state machine; close when ``|>`` arrives. + + Because ``feed()`` drives one character at a time and ``_in_special`` + is entered with an empty buffer, the close marker is always at the + tail of the buffer — there is no trailing text to recurse on. + """ self._buffer += ch - close_idx = self._buffer.find(self._SPECIAL_CLOSE) - if close_idx < 0: + if self._SPECIAL_CLOSE not in self._buffer: return - rest = self._buffer[close_idx + len(self._SPECIAL_CLOSE) :] self._buffer = "" self._in_special = False - if rest: - out.append(self.feed(rest)) - def _try_open_call(self, out: list[str]) -> bool: + def _try_open_call(self, _out: list[str]) -> bool: """If a complete ``functions....{`` opener sits in buffer, enter call mode. Returns True if the buffer was consumed (caller skips other checks); - False if the marker isn't fully present yet — caller must NOT keep - scanning the buffer for ``<|`` (the ``functions.`` prefix already - committed us to wait). + False if the marker isn't fully present yet. ``_flush_safe_prefix`` + guarantees ``functions.`` always sits at the buffer head when it's + present, and char-by-char feeding means ``{`` is always the tail — + no leading prefix to emit and no trailing text to recurse on. """ - call_idx = self._buffer.find(self._CALL_START) - if call_idx < 0: + if self._CALL_START not in self._buffer: return False - brace_idx = self._buffer.find("{", call_idx + len(self._CALL_START)) + brace_idx = self._buffer.find("{", len(self._CALL_START)) if brace_idx < 0: # Marker present but no ``{`` yet — keep buffering, do not # fall through to the ``<|`` check (it would never match # ``functions.`` and we'd over-emit). return True - if call_idx > 0: - out.append(self._buffer[:call_idx]) - rest = self._buffer[brace_idx + 1 :] self._buffer = "" self._in_call = True self._depth = 1 self._in_string = False self._escape = False - if rest: - out.append(self.feed(rest)) return True - def _try_open_special(self, out: list[str]) -> bool: - """If a ``<|...|>`` token (or its open) is in buffer, drop it; return True.""" - special_idx = self._buffer.find(self._SPECIAL_OPEN) - if special_idx < 0: + def _try_open_special(self, _out: list[str]) -> bool: + """If a ``<|`` opener sits in buffer, drop it and enter skip mode. + + ``_flush_safe_prefix`` guarantees only ``<|`` itself (no trailing + text) ever reaches us, and the closing ``|>`` is consumed later + by ``_step_in_special`` — so we only need to handle the "open + seen, no close yet" case. + """ + if self._SPECIAL_OPEN not in self._buffer: return False - close_idx = self._buffer.find(self._SPECIAL_CLOSE, special_idx + len(self._SPECIAL_OPEN)) - if close_idx >= 0: - if special_idx > 0: - out.append(self._buffer[:special_idx]) - rest = self._buffer[close_idx + len(self._SPECIAL_CLOSE) :] - self._buffer = "" - if rest: - out.append(self.feed(rest)) - return True - # Open seen but no close yet — drop everything from ``<|`` on, - # emit the prefix, enter token-skip mode. - if special_idx > 0: - out.append(self._buffer[:special_idx]) self._buffer = "" self._in_special = True return True diff --git a/pywry/pywry/cli.py b/pywry/pywry/cli.py index 65bafb9..1d85164 100644 --- a/pywry/pywry/cli.py +++ b/pywry/pywry/cli.py @@ -447,9 +447,6 @@ def show_config_sources() -> int: if forced_status is True: status = "✓ Active" path_display = "" - elif forced_status is False: - status = "✗ Not found" - path_display = path_str # Check if file exists elif name == "Environment variables": import os diff --git a/pywry/pywry/config.py b/pywry/pywry/config.py index ed4911c..42dbf6e 100644 --- a/pywry/pywry/config.py +++ b/pywry/pywry/config.py @@ -27,7 +27,7 @@ if sys.version_info >= (3, 11): import tomllib -else: +else: # pragma: no cover - python 3.10 fallback; cannot be exercised on 3.11+ try: import tomli as tomllib except ImportError: diff --git a/pywry/pywry/inline.py b/pywry/pywry/inline.py index ecf3e28..0569702 100644 --- a/pywry/pywry/inline.py +++ b/pywry/pywry/inline.py @@ -79,16 +79,11 @@ def _get_default_theme() -> ThemeLiteral: return "system" if is_headless() else "dark" -try: - import uvicorn - - from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect - from fastapi.middleware.cors import CORSMiddleware - from fastapi.responses import HTMLResponse, Response +import uvicorn - HAS_FASTAPI = True -except ImportError: - HAS_FASTAPI = False +from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse, Response try: from ipywidgets import Output @@ -1687,8 +1682,6 @@ def __init__( token: str | None = None, ) -> None: super().__init__() - if not HAS_FASTAPI: - raise ImportError("fastapi and uvicorn required: pip install fastapi uvicorn") # For browser_only mode, we don't need IPython (just the server + browser) self._browser_only = browser_only @@ -3330,8 +3323,6 @@ def generate_dataframe_html( } if grid_options: grid_config.update(grid_options) - if "rowData" not in grid_config: - grid_config["rowData"] = row_data assets = _build_aggrid_assets(aggrid_theme, theme_mode) # For system theme, default to dark AG Grid theme (JS will switch) @@ -3773,6 +3764,188 @@ def _preload_chart_data(user_id: str = "default") -> dict[str, str]: return preload +def generate_tvchart_html( + chart_html: str, + config_payload: str, + chart_id: str, + widget_id: str, + title: str = "Chart", + theme: ThemeLiteral | None = None, + toolbars: list[dict[str, Any] | Toolbar] | None = None, + modals: list[dict[str, Any] | Modal] | None = None, + inline_css: str = "", + full_document: bool = True, + token: str | None = None, +) -> str: + """Generate HTML for a TradingView Lightweight Chart. + + Parameters + ---------- + chart_html : str + The chart container ``
`` (and any toolbar/modal markup). + config_payload : str + JSON string with ``chartOptions``, ``series``, ``storage``, etc. + chart_id : str + DOM id of the chart container element. + widget_id : str + Unique widget identifier (used by the pywry bridge). + title : str + Page title. + theme : 'dark' or 'light', optional + Color theme. + toolbars : list, optional + Toolbar configurations. + modals : list, optional + Modal configurations. + inline_css : str + Extra CSS to inject. + full_document : bool + If True, return complete HTML document; if False, content fragment only. + token : str or None + Widget auth token for the pywry bridge. + + Returns + ------- + str + """ + from .assets import ( + get_pywry_css, + get_scrollbar_js, + get_toast_css, + get_tvchart_defaults_js, + get_tvchart_js, + ) + from .modal import wrap_content_with_modals + from .notebook import _wrap_content_with_toolbars + + if theme is None: + theme = _get_default_theme() + + tvchart_js = get_tvchart_js() + tvchart_script = f"" if tvchart_js else "" + tvchart_defaults = get_tvchart_defaults_js() + tvchart_defaults_script = f"" if tvchart_defaults else "" + + # Chart init script — waits for LightweightCharts then renders + chart_init_script = f"""""" + + if not full_document: + # Content fragment for anywidget — caller handles wrapping + wrapped = _wrap_content_with_toolbars(chart_html, toolbars) + if modals: + modal_html, modal_scripts = wrap_content_with_modals("", modals) + wrapped = f"{wrapped}{modal_html}{modal_scripts}" + return f"{wrapped}\n{chart_init_script}" + + # Full document for IFrame / browser mode + pywry_css = get_pywry_css() + pywry_style = f"" if pywry_css else "" + toast_css = get_toast_css() + toast_style = f"" if toast_css else "" + scrollbar_js = get_scrollbar_js() + scrollbar_script = f"" if scrollbar_js else "" + inline_style = f"" if inline_css else "" + + if theme == "dark": + widget_theme_class = "pywry-theme-dark" + elif theme == "system": + widget_theme_class = "pywry-theme-system" + else: + widget_theme_class = "pywry-theme-light" + + # Build widget content with toolbars + widget_content = wrap_content_with_toolbars(chart_html, toolbars) + + # Inject modals + modal_block = "" + if modals: + modal_html, modal_scripts = wrap_content_with_modals("", modals) + modal_block = f"{modal_html}{modal_scripts}" + + bridge_js = _get_pywry_bridge_js(widget_id, token) + + return f""" + + + + {title} + {pywry_style} + {toast_style} + {inline_style} + {scrollbar_script} + {bridge_js} + {tvchart_script} + {tvchart_defaults_script} + + + +
+ {widget_content} +
+ {modal_block} + {chart_init_script} + +""" + + def show_tvchart( data: Any = None, callbacks: dict[str, Callable[..., Any]] | None = None, diff --git a/pywry/pywry/toolbar.py b/pywry/pywry/toolbar.py index e388d92..1478b86 100644 --- a/pywry/pywry/toolbar.py +++ b/pywry/pywry/toolbar.py @@ -1113,15 +1113,12 @@ def build_html(self) -> str: f'onclick="{toggle_script}">{_EYE_ICON_SVG}' ) - if buttons_html: - input_wrapper = ( - f'' - f"{input_html}" - f'{buttons_html}' - f"" - ) - else: - input_wrapper = input_html + input_wrapper = ( + f'' + f"{input_html}" + f'{buttons_html}' + f"" + ) if self.label: return ( diff --git a/pywry/pywry/tvchart/normalize.py b/pywry/pywry/tvchart/normalize.py index bb91573..37aaca7 100644 --- a/pywry/pywry/tvchart/normalize.py +++ b/pywry/pywry/tvchart/normalize.py @@ -255,8 +255,6 @@ def _detect_symbol_column( # noqa: C901 for col in columns: if col not in _SYMBOL_ALIASES: continue - if col in _ALL_OHLCV_ALIASES: - continue if hasattr(data, "__getitem__") and hasattr(data, "__len__"): try: col_data = data[col] diff --git a/pywry/pywry/tvchart/toolbars.py b/pywry/pywry/tvchart/toolbars.py index 9ec07c9..675e5e7 100644 --- a/pywry/pywry/tvchart/toolbars.py +++ b/pywry/pywry/tvchart/toolbars.py @@ -422,9 +422,6 @@ def _time_range_presets(intervals: list[str] | None = None) -> tuple[list[Any], if value in {"all", "ytd"} or (span_lookup[value] / finest_days) >= 3 ] - if not preferred: - preferred = candidates[-3:] - selected = next( ( candidate diff --git a/pywry/tests/conftest.py b/pywry/tests/conftest.py index 5122b63..034382b 100644 --- a/pywry/tests/conftest.py +++ b/pywry/tests/conftest.py @@ -11,6 +11,17 @@ from pathlib import Path from typing import TYPE_CHECKING, Any +# Pre-import pydantic.root_model and beartype.claw._clawstate to work +# around a Pydantic + beartype + coverage interaction that breaks test +# collection when both packages are involved (e.g. anything importing +# mcp.types). Keep these imports above pytest. +import pydantic.root_model # noqa: F401 + +try: + import beartype.claw._clawstate # noqa: F401 +except ImportError: + pass + import pytest from tests.constants import ( @@ -868,3 +879,69 @@ def auth_session_manager(mock_oauth_provider, memory_token_store): token_store=memory_token_store, session_key="test_user", ) + + +# ============================================================================= +# MCP Test Fixtures +# ============================================================================= + + +@pytest.fixture +def mcp_fresh_state(): + """Reset all MCP global state before and after each test. + + Clears the singleton app, widget registry, widget configs, pending + responses, pending events, and the server-side events bucket. + """ + from pywry.mcp import state as mcp_state + from pywry.mcp.server import _events + + mcp_state._app = None + mcp_state._widgets.clear() + mcp_state._widget_configs.clear() + mcp_state._pending_responses.clear() + mcp_state._pending_events.clear() + _events.clear() + yield + mcp_state._app = None + mcp_state._widgets.clear() + mcp_state._widget_configs.clear() + mcp_state._pending_responses.clear() + mcp_state._pending_events.clear() + _events.clear() + + +@pytest.fixture +def mcp_widget(mcp_fresh_state): + """Register a single mock widget under id ``w``. + + Depends on ``mcp_fresh_state`` so the registry is clean. + """ + from unittest.mock import MagicMock + + from pywry.mcp import state as mcp_state + + widget = MagicMock() + widget.widget_id = "w" + mcp_state._widgets["w"] = widget + yield widget + + +def make_handler_ctx( + args: dict[str, Any], + headless: bool = False, + events: dict | None = None, +): + """Build a HandlerContext for unit-testing MCP handlers. + + The ``make_callback`` is a no-op so tests can focus on the handler + contract (widget.emit calls + return dict). + """ + from pywry.mcp.handlers import HandlerContext + + return HandlerContext( + args=args, + events=events if events is not None else {}, + make_callback=lambda _wid: lambda *_a, **_kw: None, + headless=headless, + ) diff --git a/pywry/tests/test_alerts.py b/pywry/tests/test_alerts.py index d11ba2a..82e4129 100644 --- a/pywry/tests/test_alerts.py +++ b/pywry/tests/test_alerts.py @@ -1015,13 +1015,6 @@ def test_pywry_alert_event_triggers_toast(self) -> None: # ============================================================================= -try: - from pywry.inline import HAS_FASTAPI -except ImportError: - HAS_FASTAPI = False - - -@pytest.mark.skipif(not HAS_FASTAPI, reason="FastAPI not installed") class TestInlineAlertE2E: """E2E tests for alerts in inline/notebook rendering.""" diff --git a/pywry/tests/test_asset_loader.py b/pywry/tests/test_asset_loader.py index f23872f..847288f 100644 --- a/pywry/tests/test_asset_loader.py +++ b/pywry/tests/test_asset_loader.py @@ -250,3 +250,174 @@ def test_tracks_file_hashes(self, tmp_path): loader.load_css("tracked.css") resolved = loader.resolve_path("tracked.css") assert resolved in loader._hash_cache + + +class TestLoadCssFromCache: + def test_returns_cached_value(self, tmp_path): + css = tmp_path / "x.css" + css.write_text("a{}") + loader = AssetLoader(base_dir=tmp_path) + loader.load_css("x.css") + css.write_text("b{}") + assert loader.load_css("x.css", use_cache=True) == "a{}" + + +class TestLoadScriptExtended: + def test_cached_value(self, tmp_path): + js = tmp_path / "x.js" + js.write_text("var x;") + loader = AssetLoader(base_dir=tmp_path) + loader.load_script("x.js") + js.write_text("var y;") + assert loader.load_script("x.js", use_cache=True) == "var x;" + + +class TestLoadAllCssAndScripts: + def test_load_all_css_concatenates(self, tmp_path): + a = tmp_path / "a.css" + b = tmp_path / "b.css" + a.write_text("a{}") + b.write_text("b{}") + loader = AssetLoader(base_dir=tmp_path) + result = loader.load_all_css(["a.css", "b.css"]) + assert "a{}" in result + assert "b{}" in result + assert "Source: a.css" in result + + def test_load_all_css_skips_empty(self, tmp_path): + a = tmp_path / "a.css" + a.write_text("a{}") + loader = AssetLoader(base_dir=tmp_path) + result = loader.load_all_css(["a.css", "missing.css"]) + assert result.count("Source:") == 1 + + def test_load_all_scripts(self, tmp_path): + a = tmp_path / "a.js" + b = tmp_path / "b.js" + a.write_text("scriptA;") + b.write_text("scriptB;") + loader = AssetLoader(base_dir=tmp_path) + result = loader.load_all_scripts(["a.js", "b.js"]) + assert len(result) == 2 + assert "scriptA;" in result[0] + assert "scriptB;" in result[1] + + +class TestGetAssetIdExtended: + def test_returns_id_format(self, tmp_path): + css = tmp_path / "my-style.css" + css.write_text("a{}") + loader = AssetLoader(base_dir=tmp_path) + loader.load_css("my-style.css") + asset_id = loader.get_asset_id("my-style.css") + assert asset_id.startswith("pywry-css-my-style-") + + def test_loads_file_when_no_hash_yet(self, tmp_path): + css = tmp_path / "fresh.css" + css.write_text("a{}") + loader = AssetLoader(base_dir=tmp_path) + asset_id = loader.get_asset_id("fresh.css") + assert asset_id.startswith("pywry-css-fresh-") + + def test_handles_missing_file_with_unknown_hash(self, tmp_path): + loader = AssetLoader(base_dir=tmp_path) + asset_id = loader.get_asset_id("missing.css") + assert "unknown" in asset_id + + +class TestHasChangedExtended: + def test_returns_true_when_no_old_hash(self, tmp_path): + css = tmp_path / "x.css" + css.write_text("a{}") + loader = AssetLoader(base_dir=tmp_path) + assert loader.has_changed("x.css") is True + + def test_returns_false_when_unchanged(self, tmp_path): + css = tmp_path / "x.css" + css.write_text("a{}") + loader = AssetLoader(base_dir=tmp_path) + loader.load_css("x.css") + assert loader.has_changed("x.css") is False + + def test_returns_true_after_modification(self, tmp_path): + css = tmp_path / "x.css" + css.write_text("a{}") + loader = AssetLoader(base_dir=tmp_path) + loader.load_css("x.css") + css.write_text("b{}") + assert loader.has_changed("x.css") is True + + def test_returns_true_on_read_error(self, tmp_path): + css = tmp_path / "x.css" + css.write_text("a{}") + loader = AssetLoader(base_dir=tmp_path) + loader.load_css("x.css") + css.unlink() + assert loader.has_changed("x.css") is True + + +class TestInvalidateAndClear: + def test_invalidate_removes_cache(self, tmp_path): + css = tmp_path / "x.css" + css.write_text("a{}") + loader = AssetLoader(base_dir=tmp_path) + loader.load_css("x.css") + assert any(loader._cache) + loader.invalidate("x.css") + assert not any(loader._cache) + assert any(loader._hash_cache) + + def test_clear_cache_removes_both(self, tmp_path): + css = tmp_path / "x.css" + css.write_text("a{}") + loader = AssetLoader(base_dir=tmp_path) + loader.load_css("x.css") + loader.clear_cache() + assert not any(loader._cache) + assert not any(loader._hash_cache) + + +class TestGlobalLoaderHelpers: + def test_get_asset_loader_singleton(self): + from pywry.asset_loader import get_asset_loader + + a = get_asset_loader() + b = get_asset_loader() + assert a is b + + def test_configure_with_base_dir(self, tmp_path): + from pywry.asset_loader import configure_asset_loader + + loader = configure_asset_loader(base_dir=tmp_path) + assert loader.base_dir == tmp_path + + def test_configure_with_settings_path(self, tmp_path): + from pywry.asset_loader import configure_asset_loader + from pywry.config import AssetSettings + + settings = AssetSettings(path=str(tmp_path)) + loader = configure_asset_loader(settings=settings) + assert loader.base_dir == tmp_path + + +class TestBuildTags: + def test_build_style_tag(self): + from pywry.asset_loader import build_style_tag + + tag = build_style_tag("body { color: red; }", "my-id") + assert 'id="my-id"' in tag + assert " None: assert "'none'" in csp or "style-src" in csp finally: server.stop() + + def test_unknown_path_returns_404(self) -> None: + """Non-/, non-/callback paths return 404.""" + from urllib.error import HTTPError + + server = OAuthCallbackServer() + server.start() + + try: + url = f"http://127.0.0.1:{server._actual_port}/unknown-path" + with pytest.raises(HTTPError) as exc_info: + urlopen(url, timeout=5) + assert exc_info.value.code == 404 + finally: + server.stop() + + def test_second_callback_returns_success_html(self) -> None: + """A second callback after the first one fired returns the success page.""" + server = OAuthCallbackServer() + server.start() + + try: + # Pre-set the result_event manually so that the handler's + # _shutdown_server() is NOT triggered by our request. + # This way, the server stays alive for the second branch. + server._result = { + "code": "preset", + "state": "preset_state", + "error": None, + "error_description": None, + } + server._result_event.set() + + # Now send a callback request — since result_event is set, the + # handler takes the "already captured" branch. + params = urlencode({"code": "ignored", "state": "ignored"}) + resp = urlopen(f"{server.redirect_uri}?{params}", timeout=5) + assert resp.status == 200 + content = resp.read().decode("utf-8") + assert "Authentication Complete" in content + + # Result is unchanged — second callback does not overwrite first + assert server._result is not None + assert server._result["code"] == "preset" + finally: + server.stop() + + def test_stop_joins_thread(self) -> None: + """stop() joins the running thread instead of leaking it.""" + server = OAuthCallbackServer() + # Build a placeholder server thread that is alive but not necessarily + # a real HTTP server, so we can exercise the join branch reliably. + stop_event = threading.Event() + + def _busy_loop() -> None: + stop_event.wait(timeout=10) + + server._thread = threading.Thread(target=_busy_loop, daemon=True) + server._thread.start() + + # Mock the _server so server.shutdown() unblocks _busy_loop + fake_server = MagicMock() + fake_server.shutdown = stop_event.set + server._server = fake_server # type: ignore[assignment] + + assert server._thread.is_alive() + server.stop() + assert server._thread is None + assert server._server is None diff --git a/pywry/tests/test_auth_deploy_routes.py b/pywry/tests/test_auth_deploy_routes.py index 3e04835..04ff859 100644 --- a/pywry/tests/test_auth_deploy_routes.py +++ b/pywry/tests/test_auth_deploy_routes.py @@ -2,16 +2,23 @@ from __future__ import annotations +import asyncio import time +from typing import Any from unittest.mock import AsyncMock, MagicMock +import pytest + from fastapi import FastAPI from fastapi.testclient import TestClient from pywry.auth.deploy_routes import ( + AuthStateStore, + LoginRateLimiter, _login_rate_limiter, _pending_auth_states, + _verify_csrf_origin, cleanup_expired_states, create_auth_router, ) @@ -51,12 +58,19 @@ def _make_mock_provider() -> MagicMock: return provider -def _make_mock_deploy_settings() -> MagicMock: +def _make_mock_deploy_settings( + *, + admin_users: list[str] | None = None, + force_https: bool = False, + auth_redirect_uri: str = "", +) -> MagicMock: """Create a mock DeploySettings.""" settings = MagicMock() settings.auth_session_cookie = "pywry_session" settings.default_roles = ["viewer"] - settings.admin_users = ["admin@test.com"] + settings.admin_users = ["admin@test.com"] if admin_users is None else admin_users + settings.force_https = force_https + settings.auth_redirect_uri = auth_redirect_uri return settings @@ -69,36 +83,70 @@ def _make_mock_session_store() -> MagicMock: return store -def _create_test_app( +def _build_app( + *, provider: MagicMock | None = None, session_store: MagicMock | None = None, token_store: MemoryTokenStore | None = None, deploy_settings: MagicMock | None = None, auth_config: AuthConfig | None = None, -) -> FastAPI: - """Create a FastAPI app with the auth router mounted.""" + inject_session: Any | None = None, +) -> tuple[FastAPI, dict[str, Any]]: + """Create a FastAPI app with the auth router mounted. + + Returns ``(app, deps)`` where ``deps`` exposes the wiring dependencies for + post-call assertions. If ``inject_session`` is provided, a middleware + attaches it to ``request.state.session`` to simulate the auth middleware. + """ app = FastAPI() - provider = provider or _make_mock_provider() - session_store = session_store or _make_mock_session_store() - token_store = token_store or MemoryTokenStore() - deploy_settings = deploy_settings or _make_mock_deploy_settings() - auth_config = auth_config or AuthConfig( - enabled=True, - token_secret="test-secret-key-for-testing", - session_ttl=3600, - ) + deps: dict[str, Any] = { + "provider": provider or _make_mock_provider(), + "session_store": session_store or _make_mock_session_store(), + "token_store": token_store or MemoryTokenStore(), + "deploy_settings": deploy_settings or _make_mock_deploy_settings(), + "auth_config": auth_config + or AuthConfig( + enabled=True, + token_secret="test-secret-key-for-testing", + session_ttl=3600, + ), + } + + if inject_session is not None: - router = create_auth_router( + @app.middleware("http") + async def add_session(request, call_next): + request.state.session = inject_session + return await call_next(request) + + router = create_auth_router(**deps) + app.include_router(router) + return app, deps + + +def _create_test_app( + provider: MagicMock | None = None, + session_store: MagicMock | None = None, + token_store: MemoryTokenStore | None = None, + deploy_settings: MagicMock | None = None, + auth_config: AuthConfig | None = None, +) -> FastAPI: + """Compat shim that returns just the FastAPI app (without deps).""" + app, _ = _build_app( provider=provider, session_store=session_store, token_store=token_store, deploy_settings=deploy_settings, auth_config=auth_config, ) - app.include_router(router) return app +def _run(coro: Any) -> Any: + """Synchronously drive a coroutine to completion.""" + return asyncio.run(coro) + + # ── Tests ──────────────────────────────────────────────────────────── @@ -421,3 +469,492 @@ def test_login_within_limit(self) -> None: assert resp.status_code == 302 _login_rate_limiter.reset() + + +# ── Unit tests: _verify_csrf_origin ───────────────────────────────── + + +def _make_csrf_request( + *, + origin: str | None, + referer: str | None, + scheme: str = "http", + host: str = "testserver", +) -> MagicMock: + """Build a mock Request for _verify_csrf_origin coverage.""" + request = MagicMock() + headers_dict: dict[str, str] = {} + if origin is not None: + headers_dict["origin"] = origin + if referer is not None: + headers_dict["referer"] = referer + + request.headers = MagicMock() + request.headers.get = lambda key, default=None: headers_dict.get(key, default) + request.url.scheme = scheme + request.url.netloc = host + return request + + +class TestVerifyCSRFOrigin: + """Cover branches in _verify_csrf_origin().""" + + def test_referer_used_when_no_origin(self) -> None: + """Referer is used as fallback when Origin is absent.""" + request = _make_csrf_request( + origin=None, + referer="http://testserver/some/path", + ) + assert _verify_csrf_origin(request) is True + + def test_referer_with_invalid_url(self) -> None: + """Invalid Referer (no scheme/netloc) is rejected.""" + request = _make_csrf_request(origin=None, referer="/relative/path") + assert _verify_csrf_origin(request) is False + + def test_origin_null_falls_back_to_referer(self) -> None: + """Origin='null' is treated as missing and falls back to Referer.""" + request = _make_csrf_request( + origin="null", + referer="http://testserver/foo", + ) + assert _verify_csrf_origin(request) is True + + def test_origin_null_no_referer_rejected(self) -> None: + """Origin='null' with no Referer is rejected.""" + request = _make_csrf_request(origin="null", referer=None) + assert _verify_csrf_origin(request) is False + + def test_trusted_origins_allowed(self) -> None: + """Origin in trusted_origins is accepted.""" + request = _make_csrf_request( + origin="https://approved.example.com", + referer=None, + ) + assert ( + _verify_csrf_origin( + request, + trusted_origins=["https://approved.example.com/"], + ) + is True + ) + + def test_trusted_origins_rejected(self) -> None: + """Origin not in trusted_origins is rejected.""" + request = _make_csrf_request( + origin="https://unknown.example.com", + referer=None, + ) + assert ( + _verify_csrf_origin( + request, + trusted_origins=["https://approved.example.com"], + ) + is False + ) + + +# ── Unit tests: LoginRateLimiter ──────────────────────────────────── + + +class TestLoginRateLimiterEviction: + """Cover the eviction branch in LoginRateLimiter.""" + + def test_old_entries_evicted(self) -> None: + """Old entries fall outside the window and are popped, freeing the slot.""" + limiter = LoginRateLimiter(max_requests=2, window_seconds=0.05) + assert limiter.is_allowed("1.1.1.1") is True + assert limiter.is_allowed("1.1.1.1") is True + # Now exhausted + assert limiter.is_allowed("1.1.1.1") is False + # Wait for window to pass + time.sleep(0.1) + # Old entries should be evicted + assert limiter.is_allowed("1.1.1.1") is True + + +# ── Unit tests: AuthStateStore ────────────────────────────────────── + + +class TestAuthStateStoreInternals: + """Cover internal AuthStateStore branches.""" + + def test_eviction_on_capacity(self) -> None: + """When at capacity, oldest entry is evicted on put().""" + store = AuthStateStore(max_pending=2, max_age=600.0) + now = time.time() + _run(store.put("a", {"created_at": now - 10, "value": "A"})) + _run(store.put("b", {"created_at": now - 5, "value": "B"})) + # Adding a third should evict 'a' (oldest) + _run(store.put("c", {"created_at": now, "value": "C"})) + assert _run(store.contains("a")) is False + assert _run(store.contains("b")) is True + assert _run(store.contains("c")) is True + + def test_evict_expired_internal(self) -> None: + """_evict_expired removes entries older than max_age.""" + store = AuthStateStore(max_pending=10, max_age=1.0) + store._store["old"] = {"created_at": time.time() - 100} + store._store["fresh"] = {"created_at": time.time()} + store._evict_expired() + assert "old" not in store._store + assert "fresh" in store._store + + def test_cleanup_returns_count(self) -> None: + """cleanup() returns the number of expired entries removed.""" + store = AuthStateStore(max_pending=10, max_age=1.0) + # Pre-populate manually because put() itself runs _evict_expired. + store._store["expired1"] = {"created_at": time.time() - 100} + store._store["expired2"] = {"created_at": time.time() - 100} + store._store["fresh"] = {"created_at": time.time()} + removed = _run(store.cleanup()) + assert removed == 2 + + def test_size_returns_count(self) -> None: + """size() returns the current number of pending states.""" + store = AuthStateStore(max_pending=10, max_age=600.0) + assert _run(store.size()) == 0 + _run(store.put("a", {"created_at": time.time()})) + assert _run(store.size()) == 1 + + +# ── /auth/login extra branches ────────────────────────────────────── + + +class TestLoginExtraBranches: + """Cover extra branches in /auth/login (configured URI, force_https).""" + + def test_login_uses_configured_redirect_uri(self) -> None: + """If auth_redirect_uri is configured, it overrides the request-derived URI.""" + _login_rate_limiter.reset() + _pending_auth_states.clear() + deploy = _make_mock_deploy_settings( + auth_redirect_uri="https://my-app.example.com/auth/callback" + ) + app, deps = _build_app(deploy_settings=deploy) + client = TestClient(app, follow_redirects=False) + resp = client.get("/auth/login") + assert resp.status_code == 302 + deps["provider"].build_authorize_url.assert_called_once() + call_kwargs = deps["provider"].build_authorize_url.call_args[1] + assert call_kwargs["redirect_uri"] == "https://my-app.example.com/auth/callback" + _login_rate_limiter.reset() + + def test_login_force_https_rewrites_uri(self) -> None: + """force_https rewrites http:// to https:// for non-localhost hosts.""" + _login_rate_limiter.reset() + _pending_auth_states.clear() + deploy = _make_mock_deploy_settings( + force_https=True, + auth_redirect_uri="http://example.com/auth/callback", + ) + app, deps = _build_app(deploy_settings=deploy) + client = TestClient(app, follow_redirects=False) + resp = client.get("/auth/login") + assert resp.status_code == 302 + call_kwargs = deps["provider"].build_authorize_url.call_args[1] + assert call_kwargs["redirect_uri"].startswith("https://example.com/") + _login_rate_limiter.reset() + + def test_login_force_https_skips_localhost(self) -> None: + """force_https leaves localhost http:// untouched.""" + _login_rate_limiter.reset() + _pending_auth_states.clear() + deploy = _make_mock_deploy_settings( + force_https=True, + auth_redirect_uri="http://localhost:8080/auth/callback", + ) + app, deps = _build_app(deploy_settings=deploy) + client = TestClient(app, follow_redirects=False) + resp = client.get("/auth/login") + assert resp.status_code == 302 + call_kwargs = deps["provider"].build_authorize_url.call_args[1] + assert call_kwargs["redirect_uri"].startswith("http://localhost") + _login_rate_limiter.reset() + + +# ── /auth/callback extra branches ─────────────────────────────────── + + +class TestCallbackPopBranch: + """Cover the auth_state pop returning None branch.""" + + def test_callback_pop_returns_none(self) -> None: + """If auth state is removed between contains() and pop(), 400 returned.""" + _pending_auth_states.clear() + _pending_auth_states["state1"] = { + "pkce_verifier": None, + "redirect_uri": "http://testserver/auth/callback", + "nonce": "nonce", + "created_at": time.time(), + } + # Patch AuthStateStore.pop to return None even though contains() returns True + from pywry.auth import deploy_routes as dr + + original_pop = dr._auth_state_store.pop + + async def patched_pop(_state: str) -> dict | None: + return None + + dr._auth_state_store.pop = patched_pop # type: ignore[assignment] + try: + app = _create_test_app() + client = TestClient(app, follow_redirects=False) + resp = client.get("/auth/callback?code=c&state=state1") + assert resp.status_code == 400 + assert resp.json()["error"] == "invalid_state" + assert "consumed" in resp.json()["error_description"] + finally: + dr._auth_state_store.pop = original_pop # type: ignore[assignment] + + +class TestCallbackUserInfoFailure: + """Cover the get_userinfo exception path in callback handler.""" + + def test_user_info_failure_continues(self) -> None: + """If get_userinfo raises, callback still creates session with 'unknown' id.""" + _pending_auth_states.clear() + _pending_auth_states["s1"] = { + "pkce_verifier": None, + "redirect_uri": "http://testserver/auth/callback", + "nonce": "nonce", + "created_at": time.time(), + } + + provider = _make_mock_provider() + provider.get_userinfo = AsyncMock(side_effect=RuntimeError("downstream fail")) + + session_store = _make_mock_session_store() + app = _create_test_app(provider=provider, session_store=session_store) + client = TestClient(app, follow_redirects=False) + resp = client.get("/auth/callback?code=c&state=s1") + + assert resp.status_code == 302 + # Session was created with 'unknown' user_id since get_userinfo failed + session_store.create_session.assert_called_once() + call_kwargs = session_store.create_session.call_args[1] + assert call_kwargs["user_id"] == "unknown" + + +class TestCallbackHTTPSCookie: + """Cover the cookie_secure branches in /auth/callback.""" + + def test_cookie_secure_when_force_https(self) -> None: + """force_https forces Secure=True on the session cookie.""" + _pending_auth_states.clear() + _pending_auth_states["s1"] = { + "pkce_verifier": None, + "redirect_uri": "https://example.com/auth/callback", + "nonce": "n", + "created_at": time.time(), + } + deploy = _make_mock_deploy_settings(force_https=True) + app = _create_test_app(deploy_settings=deploy) + client = TestClient(app, follow_redirects=False) + resp = client.get("/auth/callback?code=c&state=s1") + assert resp.status_code == 302 + cookie_header = resp.headers.get("set-cookie", "") + assert "Secure" in cookie_header + + +# ── /auth/refresh authenticated body ──────────────────────────────── + + +def _mock_session(session_id: str = "sess_xyz") -> MagicMock: + """Build a mock UserSession with sensible defaults.""" + return MagicMock(session_id=session_id, user_id="u1", roles=["viewer"]) + + +class TestRefreshAuthenticated: + """Cover the body of /auth/refresh when authenticated.""" + + def test_refresh_no_existing_tokens(self) -> None: + """No stored tokens for the session → 400 no_refresh_token.""" + token_store = MemoryTokenStore() + app, _ = _build_app(token_store=token_store, inject_session=_mock_session()) + client = TestClient(app) + resp = client.post( + "/auth/refresh", + headers={"origin": "http://testserver"}, + ) + assert resp.status_code == 400 + assert resp.json()["error"] == "no_refresh_token" + + def test_refresh_no_refresh_token_field(self) -> None: + """Stored tokens but no refresh_token → 400 no_refresh_token.""" + token_store = MemoryTokenStore() + _run( + token_store.save( + "sess_xyz", + OAuthTokenSet(access_token="at", refresh_token=None, expires_in=3600), + ) + ) + app, _ = _build_app(token_store=token_store, inject_session=_mock_session()) + client = TestClient(app) + resp = client.post( + "/auth/refresh", + headers={"origin": "http://testserver"}, + ) + assert resp.status_code == 400 + + def test_refresh_success(self) -> None: + """Valid refresh returns new token info.""" + token_store = MemoryTokenStore() + _run( + token_store.save( + "sess_xyz", + OAuthTokenSet( + access_token="at_old", + refresh_token="rt_old", + expires_in=3600, + ), + ) + ) + app, _ = _build_app(token_store=token_store, inject_session=_mock_session()) + client = TestClient(app) + resp = client.post( + "/auth/refresh", + headers={"origin": "http://testserver"}, + ) + assert resp.status_code == 200 + body = resp.json() + assert body["success"] is True + assert body["expires_in"] == 3600 + + def test_refresh_provider_failure(self) -> None: + """Provider refresh failure → 500 with sanitised error message.""" + token_store = MemoryTokenStore() + _run( + token_store.save( + "sess_xyz", + OAuthTokenSet( + access_token="at_old", + refresh_token="rt_old", + expires_in=3600, + ), + ) + ) + provider = _make_mock_provider() + provider.refresh_tokens = AsyncMock(side_effect=RuntimeError("broken")) + + app, _ = _build_app( + provider=provider, + token_store=token_store, + inject_session=_mock_session(), + ) + client = TestClient(app) + resp = client.post( + "/auth/refresh", + headers={"origin": "http://testserver"}, + ) + assert resp.status_code == 500 + assert resp.json()["error"] == "refresh_failed" + + +# ── /auth/logout authenticated body ───────────────────────────────── + + +class TestLogoutAuthenticated: + """Cover the body of /auth/logout when authenticated.""" + + def test_logout_with_session_revokes(self) -> None: + """Logout with active session revokes tokens and clears state.""" + token_store = MemoryTokenStore() + _run( + token_store.save( + "sess_xyz", + OAuthTokenSet(access_token="at", refresh_token="rt", expires_in=3600), + ) + ) + provider = _make_mock_provider() + session_store = _make_mock_session_store() + + app, _ = _build_app( + provider=provider, + token_store=token_store, + session_store=session_store, + inject_session=_mock_session(), + ) + client = TestClient(app) + resp = client.post( + "/auth/logout", + headers={"origin": "http://testserver"}, + ) + assert resp.status_code == 200 + provider.revoke_token.assert_awaited_once() + session_store.delete_session.assert_awaited_once_with("sess_xyz") + # The session-scoped tokens should be gone after logout + assert _run(token_store.load("sess_xyz")) is None + + def test_logout_revoke_swallows_exception(self) -> None: + """If revoke_token raises, logout still succeeds.""" + token_store = MemoryTokenStore() + _run( + token_store.save( + "sess_xyz", + OAuthTokenSet(access_token="at", refresh_token="rt", expires_in=3600), + ) + ) + provider = _make_mock_provider() + provider.revoke_token = AsyncMock(side_effect=RuntimeError("boom")) + session_store = _make_mock_session_store() + + app, _ = _build_app( + provider=provider, + token_store=token_store, + session_store=session_store, + inject_session=_mock_session(), + ) + client = TestClient(app) + resp = client.post( + "/auth/logout", + headers={"origin": "http://testserver"}, + ) + # Logout still returns success despite revoke failure + assert resp.status_code == 200 + + +# ── /auth/userinfo and /auth/status authenticated bodies ──────────── + + +class TestUserinfoAuthenticated: + """Cover the userinfo body when authenticated.""" + + def test_userinfo_authenticated(self) -> None: + """Authenticated userinfo returns session details from the session metadata.""" + session = MagicMock( + session_id="s1", + user_id="u1", + roles=["editor"], + metadata={"user_info": {"name": "Tester"}}, + ) + app, _ = _build_app(inject_session=session) + client = TestClient(app) + resp = client.get("/auth/userinfo") + assert resp.status_code == 200 + body = resp.json() + assert body["user_id"] == "u1" + assert body["roles"] == ["editor"] + assert body["user_info"] == {"name": "Tester"} + + +class TestStatusAuthenticated: + """Cover the status body when authenticated.""" + + def test_status_authenticated(self) -> None: + """Authenticated status returns expires_at + roles + authenticated=True.""" + now = time.time() + session = MagicMock( + session_id="s1", + user_id="u1", + roles=["viewer"], + expires_at=now + 3600, + ) + app, _ = _build_app(inject_session=session) + client = TestClient(app) + resp = client.get("/auth/status") + assert resp.status_code == 200 + body = resp.json() + assert body["authenticated"] is True + assert body["user_id"] == "u1" + assert body["expires_at"] == pytest.approx(now + 3600, abs=1) diff --git a/pywry/tests/test_auth_flow_integration.py b/pywry/tests/test_auth_flow_integration.py index 9bfff03..a2c96ff 100644 --- a/pywry/tests/test_auth_flow_integration.py +++ b/pywry/tests/test_auth_flow_integration.py @@ -3,12 +3,14 @@ from __future__ import annotations import asyncio +import builtins import contextlib +import sys import threading import time from typing import Any -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch from urllib.parse import urlencode from urllib.request import urlopen @@ -22,7 +24,7 @@ AuthFlowCancelled, AuthFlowTimeout, ) -from pywry.state.types import AuthFlowState, OAuthTokenSet +from pywry.state.types import AuthFlowResult, AuthFlowState, OAuthTokenSet # ── Helpers ────────────────────────────────────────────────────────── @@ -275,3 +277,251 @@ def test_run_deploy_no_base(self) -> None: flow = AuthFlowManager(provider=provider) url = flow.run_deploy() assert url == "/auth/login" + + +# ── Extended provider error paths ─────────────────────────────────── + + +def _make_flow_provider() -> MagicMock: + """Build a mock provider tailored for AuthFlowManager error-path tests.""" + provider = MagicMock() + provider.__class__.__name__ = "MockProvider" + provider.exchange_code = AsyncMock( + return_value=OAuthTokenSet( + access_token="at", + refresh_token="rt", + expires_in=3600, + ) + ) + provider.get_userinfo = AsyncMock(return_value={"sub": "u1"}) + provider.refresh_tokens = AsyncMock() + provider.revoke_token = AsyncMock() + return provider + + +def _send_callback_after_state_captured( + flow: AuthFlowManager, + captured_state: list[str], + params: dict[str, str], + delay: float = 0.5, +) -> threading.Thread: + """Send a callback to *flow*'s callback server once *captured_state* is populated.""" + + def _send() -> None: + time.sleep(delay) + if flow._callback_server and flow._callback_server._actual_port: + port = flow._callback_server._actual_port + qs = {**params} + if "state" not in qs and captured_state: + qs["state"] = captured_state[0] + with contextlib.suppress(Exception): + urlopen( + f"http://127.0.0.1:{port}/callback?{urlencode(qs)}", + timeout=5, + ) + + t = threading.Thread(target=_send, daemon=True) + t.start() + return t + + +def _capture_state(captured: list[str]) -> Any: + """Build a build_authorize_url side_effect that records the generated state.""" + + def cap(redirect_uri: str, state: str, **_: Any) -> str: + captured.append(state) + return f"https://mock.idp/?state={state}" + + return cap + + +class TestFlowProviderErrorPaths: + """Cover extra paths in AuthFlowManager.run_native().""" + + def test_no_authorization_code(self) -> None: + """Callback without code raises AuthenticationError.""" + provider = _make_flow_provider() + flow = AuthFlowManager(provider=provider, auth_timeout=10.0) + + captured_state: list[str] = [] + provider.build_authorize_url.side_effect = _capture_state(captured_state) + + _send_callback_after_state_captured(flow, captured_state, {}) + + with pytest.raises(AuthenticationError, match="No authorization code"): + flow.run_native() + + def test_get_userinfo_failure_logged(self) -> None: + """If get_userinfo throws, flow continues with empty user_info.""" + provider = _make_flow_provider() + provider.get_userinfo = AsyncMock(side_effect=RuntimeError("userinfo down")) + flow = AuthFlowManager(provider=provider, auth_timeout=10.0) + + captured_state: list[str] = [] + provider.build_authorize_url.side_effect = _capture_state(captured_state) + + _send_callback_after_state_captured(flow, captured_state, {"code": "code"}) + + result = flow.run_native() + assert result.success is True + assert result.user_info == {} + + def test_session_manager_save_tokens_called(self) -> None: + """If session_manager is provided, save_tokens is called with the new tokens.""" + provider = _make_flow_provider() + session_mgr = MagicMock() + session_mgr.save_tokens = AsyncMock() + flow = AuthFlowManager( + provider=provider, + session_manager=session_mgr, + auth_timeout=10.0, + ) + + captured_state: list[str] = [] + provider.build_authorize_url.side_effect = _capture_state(captured_state) + + _send_callback_after_state_captured(flow, captured_state, {"code": "c"}) + + result = flow.run_native() + assert result.success is True + session_mgr.save_tokens.assert_awaited_once() + # Ensure it received the tokens returned by the provider, not something else + saved_tokens = session_mgr.save_tokens.await_args[0][0] + assert saved_tokens.access_token == "at" + + def test_unexpected_exception_wrapped_as_auth_error(self) -> None: + """Unexpected runtime errors during the flow are wrapped in AuthenticationError.""" + provider = _make_flow_provider() + # Make exchange_code raise a non-Auth error (not a TokenError or similar) + provider.exchange_code = AsyncMock(side_effect=RuntimeError("kaboom")) + + flow = AuthFlowManager(provider=provider, auth_timeout=10.0) + captured_state: list[str] = [] + provider.build_authorize_url.side_effect = _capture_state(captured_state) + + _send_callback_after_state_captured(flow, captured_state, {"code": "c"}) + + with pytest.raises(AuthenticationError, match="Authentication flow failed"): + flow.run_native() + + +# ── authenticate() entry point ────────────────────────────────────── + + +class TestFlowAuthenticate: + """Cover AuthFlowManager.authenticate() mode-selection logic.""" + + def test_authenticate_browser_mode_returns_deploy_url(self) -> None: + """Browser/deploy mode returns AuthFlowResult with login URL hint.""" + provider = MagicMock() + provider.__class__.__name__ = "Fake" + flow = AuthFlowManager(provider=provider) + + # Construct a mock app with WindowMode.BROWSER mode_enum + from pywry.models import WindowMode + + app = MagicMock() + app._mode_enum = WindowMode.BROWSER + + result = flow.authenticate(app) + assert isinstance(result, AuthFlowResult) + assert result.success is False + assert result.error is not None + assert "/auth/login" in result.error + + def test_authenticate_native_with_show_window(self) -> None: + """Custom show_window is passed through to run_native and closed in finally.""" + provider = MagicMock() + provider.__class__.__name__ = "Fake" + provider.build_authorize_url.return_value = "https://mock.idp/?state=x" + provider.exchange_code = AsyncMock() + provider.get_userinfo = AsyncMock() + provider.revoke_token = AsyncMock() + + flow = AuthFlowManager(provider=provider, auth_timeout=0.3) + + app = MagicMock() + app._mode_enum = None # Not browser mode + + show_window = MagicMock(return_value="lbl") + close_window = MagicMock() + + with pytest.raises(AuthFlowTimeout): + flow.authenticate( + app=app, + show_window=show_window, + close_window=close_window, + ) + show_window.assert_called_once() + close_window.assert_called_once_with("lbl") + + def test_authenticate_default_browser_opener(self) -> None: + """When show_window is None, falls back to webbrowser.open with the auth URL.""" + provider = MagicMock() + provider.__class__.__name__ = "Fake" + provider.build_authorize_url.return_value = "https://mock.idp/?state=x" + provider.exchange_code = AsyncMock() + provider.get_userinfo = AsyncMock() + provider.revoke_token = AsyncMock() + + flow = AuthFlowManager(provider=provider, auth_timeout=0.3) + + app = MagicMock() + app._mode_enum = None + + with patch("webbrowser.open") as mock_open: + with pytest.raises(AuthFlowTimeout): + flow.authenticate(app=app) + mock_open.assert_called_once() + url_arg = mock_open.call_args[0][0] + assert url_arg.startswith("https://mock.idp/") + + def test_authenticate_no_mode_enum(self) -> None: + """If app._mode_enum is missing, falls through to native.""" + provider = MagicMock() + provider.__class__.__name__ = "Fake" + provider.build_authorize_url.return_value = "https://mock.idp/?state=x" + provider.exchange_code = AsyncMock() + provider.get_userinfo = AsyncMock() + provider.revoke_token = AsyncMock() + + flow = AuthFlowManager(provider=provider, auth_timeout=0.3) + + # No _mode_enum attr — getattr returns None + app = MagicMock(spec=[]) + + with patch("webbrowser.open"), pytest.raises(AuthFlowTimeout): + flow.authenticate(app=app) + + def test_authenticate_window_mode_import_error(self) -> None: + """If WindowMode import fails, falls through to native flow.""" + provider = MagicMock() + provider.__class__.__name__ = "Fake" + provider.build_authorize_url.return_value = "https://mock.idp/?state=x" + provider.exchange_code = AsyncMock() + provider.get_userinfo = AsyncMock() + provider.revoke_token = AsyncMock() + + flow = AuthFlowManager(provider=provider, auth_timeout=0.3) + + app = MagicMock() + app._mode_enum = "browser" # truthy value to enter the try-block + + # Force ImportError when WindowMode is imported + original_models = sys.modules.get("pywry.models") + sys.modules.pop("pywry.models", None) + original_import = builtins.__import__ + + def patched_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == "pywry.models" or (args and "WindowMode" in (args[2] or [])): + raise ImportError("forced fail") + return original_import(name, *args, **kwargs) + + try: + builtins.__import__ = patched_import + with patch("webbrowser.open"), pytest.raises(AuthFlowTimeout): + flow.authenticate(app=app) + finally: + builtins.__import__ = original_import + if original_models is not None: + sys.modules["pywry.models"] = original_models diff --git a/pywry/tests/test_auth_providers.py b/pywry/tests/test_auth_providers.py index 7728de4..45015b2 100644 --- a/pywry/tests/test_auth_providers.py +++ b/pywry/tests/test_auth_providers.py @@ -3,13 +3,15 @@ from __future__ import annotations import asyncio -import hashlib +import sys import time -from base64 import urlsafe_b64encode +from collections import UserDict +from typing import Any from unittest.mock import AsyncMock, MagicMock, patch from urllib.parse import parse_qs, urlparse +import httpx import pytest from pywry.auth.pkce import PKCEChallenge @@ -24,53 +26,29 @@ from pywry.state.types import OAuthTokenSet -# ── PKCE Tests ────────────────────────────────────────────────────── +# ── Test helpers ──────────────────────────────────────────────────── -class TestPKCEChallenge: - """Tests for PKCEChallenge generation.""" +def _run(coro: Any) -> Any: + """Synchronously drive a coroutine to completion.""" + return asyncio.run(coro) - def test_generate_returns_challenge(self) -> None: - """PKCEChallenge.generate() returns a valid challenge pair.""" - pkce = PKCEChallenge.generate() - assert pkce.verifier - assert pkce.challenge - assert pkce.method == "S256" - def test_verifier_is_url_safe(self) -> None: - """Verifier contains only URL-safe characters.""" - pkce = PKCEChallenge.generate() - # URL-safe base64 chars: A-Z, a-z, 0-9, -, _ - import re +def _mock_async_client(mock_instance: AsyncMock) -> AsyncMock: + """Make AsyncMock work as a context-manager friendly httpx.AsyncClient.""" + mock_instance.__aenter__ = AsyncMock(return_value=mock_instance) + mock_instance.__aexit__ = AsyncMock(return_value=False) + mock_instance.is_closed = False + mock_instance.aclose = AsyncMock() + return mock_instance - assert re.match(r"^[A-Za-z0-9_-]+$", pkce.verifier) - def test_challenge_matches_verifier_sha256(self) -> None: - """Challenge is the base64url SHA-256 of the verifier.""" - pkce = PKCEChallenge.generate() - expected_digest = hashlib.sha256(pkce.verifier.encode("ascii")).digest() - expected_challenge = urlsafe_b64encode(expected_digest).rstrip(b"=").decode("ascii") - assert pkce.challenge == expected_challenge - - def test_generate_uniqueness(self) -> None: - """Each generation produces unique values.""" - a = PKCEChallenge.generate() - b = PKCEChallenge.generate() - assert a.verifier != b.verifier - assert a.challenge != b.challenge - - def test_generate_custom_length(self) -> None: - """Custom length produces different sized verifiers.""" - short = PKCEChallenge.generate(length=32) - long = PKCEChallenge.generate(length=96) - # Longer length should generally produce longer verifier - assert len(short.verifier) < len(long.verifier) - - def test_frozen_dataclass(self) -> None: - """PKCEChallenge is immutable.""" - pkce = PKCEChallenge.generate() - with pytest.raises(AttributeError): - pkce.verifier = "new" # type: ignore +def _make_resp(payload: dict) -> MagicMock: + """Build a mocked httpx response that yields *payload* from .json().""" + resp = MagicMock() + resp.json.return_value = payload + resp.raise_for_status = MagicMock() + return resp # ── Provider URL Building ─────────────────────────────────────────── @@ -541,6 +519,647 @@ def test_unknown_provider(self) -> None: create_provider_from_settings(settings) +# ── Provider lifecycle (close / get_userinfo) ─────────────────────── + + +class TestProviderClose: + """Cover OAuthProvider.close() lifecycle.""" + + def test_close_when_no_client(self) -> None: + """close() is safe when no client has been created.""" + provider = GenericOIDCProvider(client_id="c", token_url="https://x/token") + # _http_client is None; close() should be a no-op + _run(provider.close()) + assert provider._http_client is None + + def test_close_with_open_client(self) -> None: + """close() awaits aclose() on the client and resets state.""" + provider = GenericOIDCProvider(client_id="c", token_url="https://x/token") + fake = MagicMock() + fake.is_closed = False + fake.aclose = AsyncMock() + provider._http_client = fake + _run(provider.close()) + fake.aclose.assert_awaited_once() + assert provider._http_client is None + + def test_close_already_closed(self) -> None: + """close() does nothing when client is already closed.""" + provider = GenericOIDCProvider(client_id="c", token_url="https://x/token") + fake = MagicMock() + fake.is_closed = True + fake.aclose = AsyncMock() + provider._http_client = fake + _run(provider.close()) + fake.aclose.assert_not_called() + + +class TestGetUserInfo: + """Cover OAuthProvider.get_userinfo().""" + + def test_no_userinfo_url_returns_empty(self) -> None: + """Provider without userinfo_url returns empty dict.""" + provider = GenericOIDCProvider( + client_id="c", + token_url="https://x/token", + userinfo_url="", + ) + # Force discovery so it doesn't try to discover with no issuer + provider._discovered = True + result = _run(provider.get_userinfo("at_test")) + assert result == {} + + def test_userinfo_success(self) -> None: + """get_userinfo returns user data and sends Bearer auth header.""" + provider = GenericOIDCProvider( + client_id="c", + token_url="https://x/token", + userinfo_url="https://x/userinfo", + ) + provider._discovered = True + + mock_resp = _make_resp({"sub": "user-1", "email": "u@x.com"}) + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.get = AsyncMock(return_value=mock_resp) + _mock_async_client(inst) + mock_client.return_value = inst + + data = _run(provider.get_userinfo("at_test")) + + assert data["sub"] == "user-1" + # Verify Authorization header was sent + call_kwargs = inst.get.call_args[1] + assert call_kwargs["headers"]["Authorization"] == "Bearer at_test" + + +# ── OIDC discovery / JWKS / ID-token validation ───────────────────── + + +class TestOIDCDiscovery: + """Cover GenericOIDCProvider._discover().""" + + def test_discover_skips_when_no_issuer(self) -> None: + """No issuer_url → discovery short-circuits.""" + provider = GenericOIDCProvider(client_id="c", token_url="https://x/token") + _run(provider._discover()) + assert provider._discovered is False + + def test_discover_skips_when_already_discovered(self) -> None: + """Already discovered → short-circuits.""" + provider = GenericOIDCProvider( + client_id="c", + issuer_url="https://idp.example.com", + token_url="https://x/token", + ) + provider._discovered = True + with patch("httpx.AsyncClient") as mock_client: + _run(provider._discover()) + mock_client.assert_not_called() + + def test_discover_success_populates_endpoints(self) -> None: + """Discovery populates endpoints from well-known config.""" + provider = GenericOIDCProvider( + client_id="c", + issuer_url="https://idp.example.com", + ) + + config_resp = _make_resp( + { + "issuer": "https://idp.example.com", + "authorization_endpoint": "https://idp.example.com/authorize", + "token_endpoint": "https://idp.example.com/token", + "userinfo_endpoint": "https://idp.example.com/userinfo", + "revocation_endpoint": "https://idp.example.com/revoke", + "jwks_uri": "https://idp.example.com/jwks", + } + ) + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.get = AsyncMock(return_value=config_resp) + _mock_async_client(inst) + mock_client.return_value = inst + _run(provider._discover()) + + assert provider._discovered + assert provider.authorize_url == "https://idp.example.com/authorize" + assert provider.token_url == "https://idp.example.com/token" + assert provider.userinfo_url == "https://idp.example.com/userinfo" + assert provider.revocation_url == "https://idp.example.com/revoke" + assert provider._jwks_uri == "https://idp.example.com/jwks" + + def test_discover_issuer_mismatch(self) -> None: + """Issuer mismatch raises AuthenticationError.""" + provider = GenericOIDCProvider( + client_id="c", + issuer_url="https://idp.example.com", + ) + config_resp = _make_resp({"issuer": "https://wrong-idp.com"}) + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.get = AsyncMock(return_value=config_resp) + _mock_async_client(inst) + mock_client.return_value = inst + + with pytest.raises(AuthenticationError, match="issuer mismatch"): + _run(provider._discover()) + + def test_discover_http_error_logs_warning(self) -> None: + """HTTP error during discovery is logged but not raised.""" + provider = GenericOIDCProvider( + client_id="c", + issuer_url="https://idp.example.com", + ) + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.get = AsyncMock(side_effect=httpx.HTTPError("network down")) + _mock_async_client(inst) + mock_client.return_value = inst + + _run(provider._discover()) + + assert not provider._discovered + + +class TestJWKSFetch: + """Cover GenericOIDCProvider._fetch_jwks().""" + + def test_fetch_jwks_no_uri(self) -> None: + """No JWKS URI raises TokenError.""" + provider = GenericOIDCProvider(client_id="c", token_url="https://x/token") + with pytest.raises(TokenError, match="JWKS URI"): + _run(provider._fetch_jwks()) + + def test_fetch_jwks_returns_cached(self) -> None: + """Cached JWKS is returned without HTTP call.""" + provider = GenericOIDCProvider(client_id="c", token_url="https://x/token") + provider._jwks_data = {"keys": [{"kty": "RSA"}]} + result = _run(provider._fetch_jwks()) + assert result == {"keys": [{"kty": "RSA"}]} + + def test_fetch_jwks_http_success(self) -> None: + """Fresh JWKS fetch makes HTTP request and caches the result.""" + provider = GenericOIDCProvider(client_id="c", token_url="https://x/token") + provider._jwks_uri = "https://idp.example.com/jwks" + + mock_resp = _make_resp({"keys": [{"kty": "RSA", "kid": "k1"}]}) + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.get = AsyncMock(return_value=mock_resp) + _mock_async_client(inst) + mock_client.return_value = inst + result = _run(provider._fetch_jwks()) + + assert result["keys"][0]["kid"] == "k1" + # Cached for next call + assert provider._jwks_data is not None + + +class TestValidateIDToken: + """Cover GenericOIDCProvider.validate_id_token().""" + + def test_no_authlib_raises(self) -> None: + """Missing authlib raises TokenError.""" + provider = GenericOIDCProvider(client_id="c", token_url="https://x/token") + with ( + patch("pywry.auth.providers._HAS_AUTHLIB", False), + pytest.raises(TokenError, match="authlib"), + ): + _run(provider.validate_id_token("dummy.token")) + + def test_validation_failure_wrapped_as_token_error(self) -> None: + """Failed validation wraps exception in TokenError.""" + provider = GenericOIDCProvider( + client_id="c", + issuer_url="https://idp.example.com", + token_url="https://idp.example.com/token", + ) + provider._discovered = True + provider._jwks_data = {"keys": []} + + # JWT decode will fail because token is gibberish + with pytest.raises(TokenError, match="ID token validation failed"): + _run(provider.validate_id_token("not.a.real.jwt")) + + def test_validation_with_nonce_option(self) -> None: + """Nonce parameter is wired into claims_options.""" + provider = GenericOIDCProvider( + client_id="c", + issuer_url="https://idp.example.com", + token_url="https://idp.example.com/token", + ) + provider._discovered = True + provider._jwks_data = {"keys": []} + + # Should fail validation but not before processing nonce + with pytest.raises(TokenError): + _run(provider.validate_id_token("invalid.token.value", nonce="my-nonce")) + + def test_validation_success_returns_claims(self) -> None: + """Successful validation returns the dict of claims.""" + provider = GenericOIDCProvider( + client_id="test-client", + issuer_url="https://idp.example.com", + token_url="https://idp.example.com/token", + ) + provider._discovered = True + provider._jwks_data = {"keys": []} + + class _FakeClaims(UserDict): + def validate(self) -> None: + """No-op validate to mimic authlib's claims object.""" + + fake_claims_obj = _FakeClaims({"sub": "u1", "iss": "https://idp.example.com"}) + + with ( + patch("pywry.auth.providers.JsonWebToken") as mock_jwt_cls, + patch("pywry.auth.providers.JsonWebKey") as mock_jwk, + ): + mock_jwt = MagicMock() + mock_jwt.decode.return_value = fake_claims_obj + mock_jwt_cls.return_value = mock_jwt + mock_jwk.import_key_set = MagicMock(return_value={}) + claims = _run(provider.validate_id_token("hdr.payload.sig")) + assert claims["sub"] == "u1" + + +# ── exchange_code / refresh_tokens error paths ────────────────────── + + +class TestExchangeCodeErrors: + """Cover error paths in GenericOIDCProvider.exchange_code().""" + + def test_no_token_url_raises(self) -> None: + """No token_url after discovery raises TokenError.""" + provider = GenericOIDCProvider(client_id="c", token_url="") + provider._discovered = True + with pytest.raises(TokenError, match="Token URL not configured"): + _run(provider.exchange_code("code", "http://localhost/cb")) + + def test_http_error_raises_token_error(self) -> None: + """Non-status HTTP error raises TokenError.""" + provider = GenericOIDCProvider( + client_id="c", + client_secret="s", + token_url="https://idp.example.com/token", + ) + provider._discovered = True + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.post = AsyncMock(side_effect=httpx.HTTPError("connection refused")) + _mock_async_client(inst) + mock_client.return_value = inst + + with pytest.raises(TokenError, match="Token exchange request failed"): + _run(provider.exchange_code("code", "http://localhost/cb")) + + def test_id_token_validation_invoked(self) -> None: + """ID token validation is called when require_id_token_validation.""" + provider = GenericOIDCProvider( + client_id="c", + client_secret="s", + token_url="https://idp.example.com/token", + require_id_token_validation=True, + ) + provider._discovered = True + + mock_resp = _make_resp( + { + "access_token": "at", + "id_token": "header.payload.signature", + "expires_in": 3600, + } + ) + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.post = AsyncMock(return_value=mock_resp) + _mock_async_client(inst) + mock_client.return_value = inst + + with patch.object( + provider, "validate_id_token", new_callable=AsyncMock + ) as mock_validate: + tokens = _run(provider.exchange_code("code", "http://localhost/cb")) + + mock_validate.assert_awaited_once_with("header.payload.signature", nonce=None) + assert tokens.id_token == "header.payload.signature" + + +class TestRefreshTokensErrors: + """Cover error paths in GenericOIDCProvider.refresh_tokens().""" + + def test_no_token_url_raises(self) -> None: + """No token_url after discovery raises TokenRefreshError.""" + provider = GenericOIDCProvider(client_id="c", token_url="") + provider._discovered = True + with pytest.raises(TokenRefreshError, match="Token URL not configured"): + _run(provider.refresh_tokens("rt_x")) + + def test_http_error_raises_refresh_error(self) -> None: + """Generic HTTP error raises TokenRefreshError.""" + provider = GenericOIDCProvider( + client_id="c", + token_url="https://idp.example.com/token", + ) + provider._discovered = True + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.post = AsyncMock(side_effect=httpx.HTTPError("dns failure")) + _mock_async_client(inst) + mock_client.return_value = inst + + with pytest.raises(TokenRefreshError, match="Token refresh request failed"): + _run(provider.refresh_tokens("rt_x")) + + +# ── GoogleProvider — build_authorize_url ──────────────────────────── + + +class TestGoogleAuthorize: + """Cover GoogleProvider.build_authorize_url() merge of extra_params.""" + + def test_extra_params_merged(self) -> None: + """Google merges custom extra_params with default access_type/prompt.""" + provider = GoogleProvider(client_id="c") + url = provider.build_authorize_url( + redirect_uri="http://localhost/cb", + state="s", + extra_params={"login_hint": "user@example.com"}, + ) + # Both default and custom params present + assert "access_type=offline" in url + assert "prompt=consent" in url + assert "login_hint=user%40example.com" in url + + +# ── GitHubProvider — exchange / refresh / revoke error paths ──────── + + +class TestGitHubExchangeErrors: + """Cover error paths in GitHubProvider.exchange_code().""" + + def test_status_error(self) -> None: + """HTTPStatusError yields TokenError with status code.""" + provider = GitHubProvider(client_id="c", client_secret="s") + mock_resp = MagicMock() + mock_resp.status_code = 400 + mock_resp.raise_for_status.side_effect = httpx.HTTPStatusError( + "Bad Request", request=MagicMock(), response=mock_resp + ) + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.post = AsyncMock(return_value=mock_resp) + _mock_async_client(inst) + mock_client.return_value = inst + + with pytest.raises(TokenError, match="GitHub token exchange failed"): + _run(provider.exchange_code("code", "http://localhost/cb")) + + def test_http_error(self) -> None: + """Generic HTTPError yields TokenError.""" + provider = GitHubProvider(client_id="c", client_secret="s") + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.post = AsyncMock(side_effect=httpx.HTTPError("dns")) + _mock_async_client(inst) + mock_client.return_value = inst + + with pytest.raises(TokenError, match="GitHub token exchange request failed"): + _run(provider.exchange_code("code", "http://localhost/cb")) + + def test_success(self) -> None: + """Successful GitHub exchange returns OAuthTokenSet with all fields.""" + provider = GitHubProvider(client_id="c", client_secret="s") + + mock_resp = _make_resp( + { + "access_token": "ghp_abc123", + "token_type": "bearer", + "scope": "read:user,user:email", + "expires_in": 28800, + "refresh_token": "ghr_xyz", + } + ) + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.post = AsyncMock(return_value=mock_resp) + _mock_async_client(inst) + mock_client.return_value = inst + + tokens = _run(provider.exchange_code("code", "http://localhost/cb")) + + assert tokens.access_token == "ghp_abc123" + assert tokens.refresh_token == "ghr_xyz" + assert tokens.expires_in == 28800 + assert tokens.token_type == "bearer" + + +class TestGitHubRefreshTokens: + """Cover GitHubProvider.refresh_tokens().""" + + def test_success(self) -> None: + """Successful GitHub refresh returns new tokens with all fields.""" + provider = GitHubProvider(client_id="c", client_secret="s") + + mock_resp = _make_resp( + { + "access_token": "at_refreshed", + "refresh_token": "rt_new", + "token_type": "bearer", + "expires_in": 7200, + "scope": "read:user", + } + ) + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.post = AsyncMock(return_value=mock_resp) + _mock_async_client(inst) + mock_client.return_value = inst + + tokens = _run(provider.refresh_tokens("rt_old")) + + assert tokens.access_token == "at_refreshed" + assert tokens.refresh_token == "rt_new" + assert tokens.expires_in == 7200 + + def test_http_error(self) -> None: + """HTTP error during GitHub refresh raises TokenRefreshError.""" + provider = GitHubProvider(client_id="c", client_secret="s") + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.post = AsyncMock(side_effect=httpx.HTTPError("connect")) + _mock_async_client(inst) + mock_client.return_value = inst + + with pytest.raises(TokenRefreshError, match="GitHub token refresh failed"): + _run(provider.refresh_tokens("rt_x")) + + def test_error_in_body(self) -> None: + """GitHub returns error in JSON body during refresh.""" + provider = GitHubProvider(client_id="c", client_secret="s") + + mock_resp = _make_resp({"error": "bad_refresh_token", "error_description": "invalid token"}) + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.post = AsyncMock(return_value=mock_resp) + _mock_async_client(inst) + mock_client.return_value = inst + + with pytest.raises(TokenRefreshError, match="invalid token"): + _run(provider.refresh_tokens("rt_x")) + + +class TestGitHubRevokeError: + """Cover GitHubProvider.revoke_token() exception path.""" + + def test_revoke_http_error(self) -> None: + """HTTPError during revoke returns False.""" + provider = GitHubProvider(client_id="gh_id", client_secret="gh_sec") + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.request = AsyncMock(side_effect=httpx.HTTPError("conn")) + _mock_async_client(inst) + mock_client.return_value = inst + result = _run(provider.revoke_token("ghp_x")) + assert result is False + + +# ── MicrosoftProvider — _discover branches ────────────────────────── + + +class TestMicrosoftDiscover: + """Cover MicrosoftProvider._discover() variations.""" + + def test_discover_skips_when_already_discovered(self) -> None: + """Already discovered MS provider short-circuits.""" + provider = MicrosoftProvider(client_id="c", tenant_id="my-tenant") + provider._discovered = True + with patch("httpx.AsyncClient") as mock_client: + _run(provider._discover()) + mock_client.assert_not_called() + + def test_discover_skips_when_no_issuer(self) -> None: + """MS provider with empty issuer_url short-circuits.""" + provider = MicrosoftProvider(client_id="c") + provider.issuer_url = "" + with patch("httpx.AsyncClient") as mock_client: + _run(provider._discover()) + mock_client.assert_not_called() + + def test_discover_with_tenantid_placeholder(self) -> None: + """MS issuer with `{tenantid}` placeholder is normalized.""" + provider = MicrosoftProvider(client_id="c", tenant_id="my-tenant") + # Don't set explicit URLs - allow discovery to populate them + provider.authorize_url = "" + provider.token_url = "" + provider.userinfo_url = "" + provider.revocation_url = "" + + config_resp = _make_resp( + { + # Microsoft uses {tenantid} placeholder + "issuer": "https://login.microsoftonline.com/{tenantid}/v2.0", + "authorization_endpoint": ( + "https://login.microsoftonline.com/my-tenant/oauth2/v2.0/authorize" + ), + "token_endpoint": ("https://login.microsoftonline.com/my-tenant/oauth2/v2.0/token"), + "userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo", + "revocation_endpoint": "", + "jwks_uri": ("https://login.microsoftonline.com/my-tenant/discovery/v2.0/keys"), + } + ) + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.get = AsyncMock(return_value=config_resp) + _mock_async_client(inst) + mock_client.return_value = inst + _run(provider._discover()) + + assert provider._discovered + assert "/authorize" in provider.authorize_url + assert "/token" in provider.token_url + assert provider.userinfo_url == "https://graph.microsoft.com/oidc/userinfo" + assert provider._jwks_uri.endswith("/discovery/v2.0/keys") + + def test_discover_real_issuer_mismatch_raises(self) -> None: + """A non-matching MS issuer (after normalization) still raises.""" + provider = MicrosoftProvider(client_id="c", tenant_id="my-tenant") + config_resp = _make_resp({"issuer": "https://login.microsoftonline.com/wrong-tenant/v2.0"}) + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.get = AsyncMock(return_value=config_resp) + _mock_async_client(inst) + mock_client.return_value = inst + + with pytest.raises(AuthenticationError, match="issuer mismatch"): + _run(provider._discover()) + + def test_discover_http_error_logs_warning(self) -> None: + """HTTP error during MS discovery is logged but not raised.""" + provider = MicrosoftProvider(client_id="c", tenant_id="my-tenant") + + with patch("httpx.AsyncClient") as mock_client: + inst = AsyncMock() + inst.get = AsyncMock(side_effect=httpx.HTTPError("network down")) + _mock_async_client(inst) + mock_client.return_value = inst + _run(provider._discover()) + + assert not provider._discovered + + +# ── Module-level — authlib import fallback ────────────────────────── + + +class TestAuthlibImportFallback: + """Cover the ImportError fallback when authlib is missing.""" + + def test_reimport_with_missing_authlib(self) -> None: + """When authlib import fails at module load, _HAS_AUTHLIB becomes False.""" + import importlib + + # Save the existing module + original = sys.modules.get("pywry.auth.providers") + + # Remove authlib + jose modules to force ImportError + authlib_modules = [k for k in list(sys.modules) if k.startswith("authlib")] + saved = {k: sys.modules[k] for k in authlib_modules} + for k in authlib_modules: + sys.modules[k] = None # type: ignore[assignment] + + try: + sys.modules.pop("pywry.auth.providers", None) + # Re-import — should trigger the ImportError fallback path + reimported = importlib.import_module("pywry.auth.providers") + assert reimported._HAS_AUTHLIB is False + finally: + # Restore authlib modules + for k, v in saved.items(): + sys.modules[k] = v + # Restore the original providers module + if original is not None: + sys.modules["pywry.auth.providers"] = original + else: + sys.modules.pop("pywry.auth.providers", None) + importlib.import_module("pywry.auth.providers") + + # ── OAuthTokenSet ──────────────────────────────────────────────────── diff --git a/pywry/tests/test_auth_session.py b/pywry/tests/test_auth_session.py index fdb864f..dbb1b7b 100644 --- a/pywry/tests/test_auth_session.py +++ b/pywry/tests/test_auth_session.py @@ -297,3 +297,177 @@ def test_cancel_refresh_timer( _run(mgr.save_tokens(valid_tokens)) mgr._cancel_refresh_timer() assert mgr._refresh_timer is None + + +# ── Extended refresh paths ────────────────────────────────────────── + + +def _make_provider_with_refresh( + refresh_side_effect: object | None = None, +) -> MagicMock: + """Build a mock provider whose refresh_tokens returns a fresh OAuthTokenSet.""" + provider = MagicMock() + provider.__class__.__name__ = "FakeProvider" + provider.revoke_token = AsyncMock() + if refresh_side_effect is not None: + provider.refresh_tokens = AsyncMock(side_effect=refresh_side_effect) + else: + provider.refresh_tokens = AsyncMock( + return_value=OAuthTokenSet( + access_token="at_new", + refresh_token="rt_new", + expires_in=3600, + ) + ) + return provider + + +class TestSessionManagerRefreshPaths: + """Cover edge paths in SessionManager.refresh() / get_access_token().""" + + def test_initialize_refresh_failure_returns_none(self) -> None: + """If refresh fails on init, return None.""" + provider = _make_provider_with_refresh(refresh_side_effect=RuntimeError("provider down")) + store = MemoryTokenStore() + expired = OAuthTokenSet( + access_token="at_old", + refresh_token="rt_old", + expires_in=3600, + issued_at=time.time() - 7200, + ) + _run(store.save("default", expired)) + + mgr = SessionManager(provider, store) + result = _run(mgr.initialize()) + assert result is None + + def test_get_access_token_refreshes_when_expired(self) -> None: + """get_access_token triggers refresh when token is expired.""" + provider = _make_provider_with_refresh() + store = MemoryTokenStore() + expired = OAuthTokenSet( + access_token="at_old", + refresh_token="rt_old", + expires_in=3600, + issued_at=time.time() - 7200, + ) + mgr = SessionManager(provider, store) + _run(mgr.save_tokens(expired)) + # Expired -> triggers refresh -> returns at_new + token = _run(mgr.get_access_token()) + assert token == "at_new" + + def test_refresh_loads_from_store_when_no_current(self) -> None: + """If _current_tokens is None, refresh() loads from store.""" + provider = _make_provider_with_refresh() + store = MemoryTokenStore() + tokens = OAuthTokenSet( + access_token="at", + refresh_token="rt", + expires_in=3600, + issued_at=time.time(), + ) + _run(store.save("default", tokens)) + mgr = SessionManager(provider, store) + # _current_tokens is None until we initialize + new_tokens = _run(mgr.refresh()) + assert new_tokens.access_token == "at_new" + + def test_refresh_failure_calls_reauth_callback(self) -> None: + """When provider.refresh_tokens raises, on_reauth_required is called.""" + provider = _make_provider_with_refresh(refresh_side_effect=RuntimeError("provider down")) + store = MemoryTokenStore() + tokens = OAuthTokenSet( + access_token="at", + refresh_token="rt", + expires_in=3600, + issued_at=time.time(), + ) + reauth_calls: list[bool] = [] + mgr = SessionManager( + provider, + store, + on_reauth_required=lambda: reauth_calls.append(True), + ) + _run(mgr.save_tokens(tokens)) + with pytest.raises(TokenRefreshError, match="Token refresh failed"): + _run(mgr.refresh()) + assert reauth_calls == [True] + + def test_refresh_fallback_no_refresh_token_with_reauth(self) -> None: + """If no refresh_token and on_reauth_required is set, callback fires first.""" + provider = _make_provider_with_refresh() + store = MemoryTokenStore() + no_refresh = OAuthTokenSet( + access_token="at", + refresh_token=None, + expires_in=3600, + issued_at=time.time(), + ) + reauth_calls: list[bool] = [] + mgr = SessionManager( + provider, + store, + on_reauth_required=lambda: reauth_calls.append(True), + ) + _run(mgr.save_tokens(no_refresh)) + with pytest.raises(TokenRefreshError, match="Re-authentication required"): + _run(mgr.refresh()) + assert reauth_calls == [True] + + +class TestSessionManagerScheduleRefresh: + """Cover branch in _schedule_refresh where delay <= 0.""" + + def test_schedule_when_already_expired_uses_immediate(self) -> None: + """When token already expired (delay <= 0), schedules with delay=1.0.""" + provider = MagicMock() + provider.__class__.__name__ = "Fake" + store = MemoryTokenStore() + # Token already expired (issued an hour+ ago, but expires_in is 3600) + expired = OAuthTokenSet( + access_token="at", + refresh_token="rt", + expires_in=3600, + issued_at=time.time() - 7200, + ) + mgr = SessionManager(provider, store, refresh_buffer_seconds=60) + # Calling _schedule_refresh directly to avoid triggering refresh logic + mgr._schedule_refresh(expired) + assert mgr._refresh_timer is not None + # Cancel before it fires + mgr._cancel_refresh_timer() + + +class TestSessionManagerBackgroundRefresh: + """Cover SessionManager._do_background_refresh exception path.""" + + def test_background_refresh_failure_calls_reauth(self) -> None: + """Background refresh failure invokes on_reauth_required.""" + provider = MagicMock() + provider.__class__.__name__ = "Fake" + provider.refresh_tokens = AsyncMock(side_effect=RuntimeError("provider down")) + provider.revoke_token = AsyncMock() + store = MemoryTokenStore() + reauth_calls: list[bool] = [] + + mgr = SessionManager( + provider, + store, + on_reauth_required=lambda: reauth_calls.append(True), + ) + # Pre-populate with valid tokens (so _current_tokens is set) + tokens = OAuthTokenSet( + access_token="at", + refresh_token="rt", + expires_in=3600, + ) + _run(mgr.save_tokens(tokens)) + # Cancel scheduled timer so we drive it manually + mgr._cancel_refresh_timer() + + # Run the background refresh — should not raise + mgr._do_background_refresh() + # refresh() raises -> exception handler calls reauth; the inner refresh() + # may also call it for no-refresh-token or provider-down branches. + assert len(reauth_calls) >= 1 diff --git a/pywry/tests/test_auth_token_store.py b/pywry/tests/test_auth_token_store.py index 54e7fa0..c1db8d1 100644 --- a/pywry/tests/test_auth_token_store.py +++ b/pywry/tests/test_auth_token_store.py @@ -4,14 +4,18 @@ import asyncio import json +import sys import time -from unittest.mock import patch +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch import pytest from pywry.auth.token_store import ( + KeyringTokenStore, MemoryTokenStore, + RedisTokenStore, _deserialize_tokens, _serialize_tokens, get_token_store, @@ -135,16 +139,242 @@ def test_overwrite(self, memory_store: MemoryTokenStore, sample_tokens: OAuthTok # ── KeyringTokenStore ─────────────────────────────────────────────── -class TestKeyringTokenStore: - """Tests for KeyringTokenStore with mocked keyring.""" +@pytest.fixture() +def fake_keyring_store() -> tuple[KeyringTokenStore, MagicMock]: + """Build a KeyringTokenStore with its _keyring attribute replaced by a mock.""" + fake = MagicMock() + store = KeyringTokenStore(service_name="pywry-test") + store._keyring = fake + return store, fake + + +class TestKeyringTokenStoreImport: + """Tests for KeyringTokenStore optional-dependency handling.""" def test_import_error(self) -> None: """Missing keyring raises ImportError.""" - with patch.dict("sys.modules", {"keyring": None}): - from pywry.auth.token_store import KeyringTokenStore + with ( + patch.dict("sys.modules", {"keyring": None}), + pytest.raises(ImportError, match="keyring"), + ): + KeyringTokenStore() + + +class TestKeyringTokenStore: + """CRUD operations against KeyringTokenStore using a mocked keyring backend.""" + + def test_save_calls_set_password( + self, + fake_keyring_store: tuple[KeyringTokenStore, MagicMock], + ) -> None: + """save() invokes keyring.set_password with serialized tokens.""" + store, fake = fake_keyring_store + fake.set_password = MagicMock() + tokens = OAuthTokenSet(access_token="at_k", expires_in=3600) + asyncio.run(store.save("user1", tokens)) + + fake.set_password.assert_called_once() + args = fake.set_password.call_args[0] + assert args[0] == "pywry-test" + assert args[1] == "user1" + assert "at_k" in args[2] + + def test_load_returns_tokens( + self, + fake_keyring_store: tuple[KeyringTokenStore, MagicMock], + ) -> None: + """load() round-trips tokens from keyring.""" + store, fake = fake_keyring_store + tokens = OAuthTokenSet(access_token="at_k", expires_in=3600) + fake.get_password = MagicMock(return_value=_serialize_tokens(tokens)) + result = asyncio.run(store.load("user1")) + assert result is not None + assert result.access_token == "at_k" + + def test_load_missing_returns_none( + self, + fake_keyring_store: tuple[KeyringTokenStore, MagicMock], + ) -> None: + """load() returns None for missing key.""" + store, fake = fake_keyring_store + fake.get_password = MagicMock(return_value=None) + result = asyncio.run(store.load("missing")) + assert result is None + + def test_delete_calls_delete_password( + self, + fake_keyring_store: tuple[KeyringTokenStore, MagicMock], + ) -> None: + """delete() invokes keyring.delete_password.""" + store, fake = fake_keyring_store + fake.delete_password = MagicMock() + asyncio.run(store.delete("user1")) + fake.delete_password.assert_called_once() + + def test_delete_swallows_exceptions( + self, + fake_keyring_store: tuple[KeyringTokenStore, MagicMock], + ) -> None: + """delete() swallows keyring errors instead of bubbling them up.""" + store, fake = fake_keyring_store + fake.delete_password = MagicMock(side_effect=RuntimeError("not found")) + # Must not raise + asyncio.run(store.delete("missing")) + + def test_exists_via_load( + self, + fake_keyring_store: tuple[KeyringTokenStore, MagicMock], + ) -> None: + """exists() uses load() under the hood and returns True/False accordingly.""" + store, fake = fake_keyring_store + fake.get_password = MagicMock( + return_value=_serialize_tokens(OAuthTokenSet(access_token="x")) + ) + assert asyncio.run(store.exists("u")) is True + + fake.get_password = MagicMock(return_value=None) + assert asyncio.run(store.exists("u")) is False - with pytest.raises(ImportError, match="keyring"): - KeyringTokenStore() + def test_list_keys_returns_empty( + self, + fake_keyring_store: tuple[KeyringTokenStore, MagicMock], + ) -> None: + """list_keys() always returns [] since keyring offers no enumeration API.""" + store, _ = fake_keyring_store + keys = asyncio.run(store.list_keys()) + assert keys == [] + + +# ── RedisTokenStore ───────────────────────────────────────────────── + + +@pytest.fixture() +def fake_redis_store() -> tuple[RedisTokenStore, AsyncMock]: + """Build a RedisTokenStore with its _redis attribute replaced by a mock.""" + store = RedisTokenStore(redis_url="redis://localhost:6379/0", prefix="pywry-test") + fake = AsyncMock() + store._redis = fake + return store, fake + + +class TestRedisTokenStoreImport: + """Tests for RedisTokenStore optional-dependency handling.""" + + def test_missing_redis_raises(self) -> None: + """RedisTokenStore raises ImportError when redis.asyncio is absent.""" + original = sys.modules.get("redis.asyncio") + sys.modules["redis.asyncio"] = None # type: ignore[assignment] + try: + with pytest.raises(ImportError, match="Redis backend"): + RedisTokenStore() + finally: + if original is not None: + sys.modules["redis.asyncio"] = original + else: + sys.modules.pop("redis.asyncio", None) + + +class TestRedisTokenStore: + """CRUD operations against RedisTokenStore using a mocked async client.""" + + def test_key_format(self, fake_redis_store: tuple[RedisTokenStore, AsyncMock]) -> None: + """_key() builds the correct Redis key.""" + store, _ = fake_redis_store + assert store._key("u1") == "pywry-test:oauth:tokens:u1" + + def test_save_with_expiry_uses_setex( + self, + fake_redis_store: tuple[RedisTokenStore, AsyncMock], + ) -> None: + """save() with expires_in uses setex with TTL + 300s buffer.""" + store, fake = fake_redis_store + fake.setex = AsyncMock(return_value=True) + tokens = OAuthTokenSet(access_token="at", expires_in=3600) + asyncio.run(store.save("user1", tokens)) + fake.setex.assert_awaited_once() + args = fake.setex.call_args[0] + assert args[0] == "pywry-test:oauth:tokens:user1" + assert args[1] == 3600 + 300 # buffer + assert "at" in args[2] + + def test_save_no_expiry_uses_set( + self, + fake_redis_store: tuple[RedisTokenStore, AsyncMock], + ) -> None: + """save() without expires_in uses plain SET (no TTL).""" + store, fake = fake_redis_store + fake.set = AsyncMock(return_value=True) + tokens = OAuthTokenSet(access_token="at_noexp", expires_in=None) + asyncio.run(store.save("u1", tokens)) + fake.set.assert_awaited_once() + + def test_save_zero_expiry_uses_set( + self, + fake_redis_store: tuple[RedisTokenStore, AsyncMock], + ) -> None: + """save() with expires_in=0 uses plain SET (not SETEX).""" + store, fake = fake_redis_store + fake.set = AsyncMock(return_value=True) + tokens = OAuthTokenSet(access_token="at_zero", expires_in=0) + asyncio.run(store.save("u1", tokens)) + fake.set.assert_awaited_once() + + def test_load_returns_tokens( + self, + fake_redis_store: tuple[RedisTokenStore, AsyncMock], + ) -> None: + """load() deserializes tokens from Redis.""" + store, fake = fake_redis_store + tokens = OAuthTokenSet(access_token="at_loaded", expires_in=600) + fake.get = AsyncMock(return_value=_serialize_tokens(tokens)) + result = asyncio.run(store.load("u1")) + assert result is not None + assert result.access_token == "at_loaded" + + def test_load_missing( + self, + fake_redis_store: tuple[RedisTokenStore, AsyncMock], + ) -> None: + """load() returns None when Redis key missing.""" + store, fake = fake_redis_store + fake.get = AsyncMock(return_value=None) + result = asyncio.run(store.load("missing")) + assert result is None + + def test_delete(self, fake_redis_store: tuple[RedisTokenStore, AsyncMock]) -> None: + """delete() invokes Redis DELETE with the prefixed key.""" + store, fake = fake_redis_store + fake.delete = AsyncMock(return_value=1) + asyncio.run(store.delete("u1")) + fake.delete.assert_awaited_with("pywry-test:oauth:tokens:u1") + + def test_exists(self, fake_redis_store: tuple[RedisTokenStore, AsyncMock]) -> None: + """exists() reflects Redis EXISTS as a boolean.""" + store, fake = fake_redis_store + fake.exists = AsyncMock(return_value=1) + assert asyncio.run(store.exists("u1")) is True + + fake.exists = AsyncMock(return_value=0) + assert asyncio.run(store.exists("u1")) is False + + def test_list_keys_via_scan_iter( + self, + fake_redis_store: tuple[RedisTokenStore, AsyncMock], + ) -> None: + """list_keys() iterates Redis SCAN and strips the prefix.""" + store, fake = fake_redis_store + + async def fake_scan_iter(match: str) -> Any: + for key in [ + "pywry-test:oauth:tokens:user1", + "pywry-test:oauth:tokens:user2", + ]: + yield key + + fake.scan_iter = fake_scan_iter + + keys = asyncio.run(store.list_keys()) + assert sorted(keys) == ["user1", "user2"] # ── get_token_store factory ───────────────────────────────────────── @@ -158,15 +388,46 @@ def test_memory_backend(self) -> None: reset_token_store() store = get_token_store("memory") assert isinstance(store, MemoryTokenStore) + reset_token_store() def test_default_is_memory(self) -> None: """Default backend is memory.""" reset_token_store() store = get_token_store() assert isinstance(store, MemoryTokenStore) + reset_token_store() def test_unknown_backend_raises(self) -> None: """Unknown backend raises ValueError.""" reset_token_store() with pytest.raises(ValueError, match="Unknown"): get_token_store("nonexistent") + + def test_get_returns_cached_singleton(self) -> None: + """Subsequent calls return the same instance.""" + reset_token_store() + first = get_token_store("memory") + second = get_token_store("memory") + assert first is second + reset_token_store() + + def test_get_keyring_backend(self) -> None: + """`keyring` backend instantiates KeyringTokenStore.""" + reset_token_store() + store = get_token_store("keyring", service_name="pywry-test-extra") + assert isinstance(store, KeyringTokenStore) + assert store._service_name == "pywry-test-extra" + reset_token_store() + + def test_get_redis_backend(self) -> None: + """`redis` backend instantiates RedisTokenStore.""" + reset_token_store() + store = get_token_store( + "redis", + redis_url="redis://localhost:6379/0", + prefix="px-test", + pool_size=5, + ) + assert isinstance(store, RedisTokenStore) + assert store._prefix == "px-test" + reset_token_store() diff --git a/pywry/tests/test_browser_mode_e2e.py b/pywry/tests/test_browser_mode_e2e.py index 0f8b247..76b3204 100644 --- a/pywry/tests/test_browser_mode_e2e.py +++ b/pywry/tests/test_browser_mode_e2e.py @@ -39,7 +39,6 @@ from pywry.config import clear_settings from pywry.inline import ( - HAS_FASTAPI, InlineWidget, _start_server, _state, @@ -47,11 +46,8 @@ show_plotly, stop_server, ) -from pywry.models import WindowMode - -# Skip all tests if FastAPI not installed -pytestmark = pytest.mark.skipif(not HAS_FASTAPI, reason="FastAPI not installed") +from pywry.models import WindowMode def wait_for_port_release(port: int, host: str = "127.0.0.1", timeout: float = 5.0) -> bool: diff --git a/pywry/tests/test_chat.py b/pywry/tests/test_chat.py index 4bbd5a4..fbb1ab5 100644 --- a/pywry/tests/test_chat.py +++ b/pywry/tests/test_chat.py @@ -1,40 +1,28 @@ """Unit tests for the chat component. Tests cover: -- ACP content block models (TextPart, ImagePart, AudioPart, etc.) -- ACPToolCall model -- ChatMessage, ChatThread, ChatConfig -- GenerationHandle (cancel, append_chunk, partial_content, is_expired) - ChatStateMixin: all chat state management methods - ChatStore ABC + MemoryChatStore implementation - Chat builder functions -- ACPCommand model +- build_chat_html +- Provider factory +- Session primitives +- SessionUpdate types +- Artifacts validation +- Permissions + +Model tests live in ``test_chat_models.py``. """ from __future__ import annotations -import time - from typing import Any import pytest from pywry.chat import ( - GENERATION_HANDLE_TTL, - MAX_CONTENT_LENGTH, - ACPCommand, - ACPToolCall, - AudioPart, - ChatConfig, ChatMessage, ChatThread, - ChatWidgetConfig, - EmbeddedResource, - EmbeddedResourcePart, - GenerationHandle, - ImagePart, - ResourceLinkPart, - TextPart, build_chat_html, ) from pywry.state_mixins import ChatStateMixin, EmittingWidget @@ -65,296 +53,6 @@ class MockChatWidget(MockEmitter, ChatStateMixin): """Mock widget combining emitter with ChatStateMixin.""" -# ============================================================================= -# ChatMessage Tests -# ============================================================================= - - -class TestChatMessage: - """Test ChatMessage model.""" - - def test_basic_creation(self) -> None: - msg = ChatMessage(role="user", content="Hello") - assert msg.role == "user" - assert msg.text_content() == "Hello" - assert msg.message_id - assert msg.stopped is False - - def test_string_content(self) -> None: - msg = ChatMessage(role="assistant", content="Hi there") - assert msg.text_content() == "Hi there" - - def test_list_content_text_parts(self) -> None: - msg = ChatMessage( - role="assistant", - content=[ - TextPart(text="Hello "), - TextPart(text="world"), - ], - ) - assert msg.text_content() == "Hello world" - - def test_list_content_mixed_parts(self) -> None: - msg = ChatMessage( - role="assistant", - content=[ - TextPart(text="See image: "), - ImagePart(data="base64data", mimeType="image/png"), - ], - ) - assert msg.text_content() == "See image: " - - def test_content_length_validation(self) -> None: - msg = ChatMessage(role="user", content="x" * 100) - assert len(msg.text_content()) == 100 - - def test_content_too_long_raises(self) -> None: - from pydantic import ValidationError - - with pytest.raises(ValidationError): - ChatMessage(role="user", content="x" * (MAX_CONTENT_LENGTH + 1)) - - def test_tool_calls(self) -> None: - msg = ChatMessage( - role="assistant", - content="I'll search for that.", - tool_calls=[ - ACPToolCall( - toolCallId="call_1", - name="search", - kind="fetch", - arguments={"query": "test"}, - ), - ], - ) - assert len(msg.tool_calls) == 1 - assert msg.tool_calls[0].name == "search" - assert msg.tool_calls[0].kind == "fetch" - - def test_stopped_field(self) -> None: - msg = ChatMessage(role="assistant", content="Partial", stopped=True) - assert msg.stopped is True - - def test_metadata(self) -> None: - msg = ChatMessage( - role="assistant", - content="Result", - metadata={"model": "gpt-4", "usage": {"tokens": 42}}, - ) - assert msg.metadata["model"] == "gpt-4" - - -class TestChatThread: - """Test ChatThread model.""" - - def test_creation(self) -> None: - thread = ChatThread(thread_id="t1", title="Test Thread") - assert thread.thread_id == "t1" - assert thread.title == "Test Thread" - assert thread.messages == [] - - def test_with_messages(self) -> None: - msg = ChatMessage(role="user", content="Hello") - thread = ChatThread(thread_id="t1", title="Chat", messages=[msg]) - assert len(thread.messages) == 1 - - -class TestACPCommand: - """Test ACPCommand model.""" - - def test_creation(self) -> None: - cmd = ACPCommand(name="web", description="Search the web") - assert cmd.name == "web" - assert cmd.description == "Search the web" - - def test_with_input(self) -> None: - from pywry.chat.models import ACPCommandInput - - cmd = ACPCommand( - name="test", - description="Run tests", - input=ACPCommandInput(hint="Enter test name"), - ) - assert cmd.input.hint == "Enter test name" - - -class TestACPToolCall: - """Test ACPToolCall model.""" - - def test_creation(self) -> None: - tc = ACPToolCall( - toolCallId="call_1", - title="Read file", - name="fs_read", - kind="read", - status="pending", - ) - assert tc.tool_call_id == "call_1" - assert tc.kind == "read" - assert tc.status == "pending" - - def test_defaults(self) -> None: - tc = ACPToolCall(name="test") - assert tc.tool_call_id # auto-generated - assert tc.kind == "other" - assert tc.status == "pending" - - def test_with_arguments(self) -> None: - tc = ACPToolCall( - name="search", - arguments={"query": "hello"}, - ) - assert tc.arguments["query"] == "hello" - - -class TestChatConfig: - """Test ChatConfig model.""" - - def test_defaults(self) -> None: - config = ChatConfig() - assert config.model == "gpt-4" - assert config.temperature == 0.7 - assert config.max_tokens == 4096 - assert config.streaming is True - assert config.persist is False - - def test_custom_values(self) -> None: - config = ChatConfig( - system_prompt="You are helpful.", - model="claude-3", - temperature=0.3, - ) - assert config.system_prompt == "You are helpful." - assert config.model == "claude-3" - - -class TestChatWidgetConfig: - """Test ChatWidgetConfig model.""" - - def test_defaults(self) -> None: - config = ChatWidgetConfig() - assert config.title == "Chat" - assert config.height == 700 - assert config.show_sidebar is True - - def test_with_chat_config(self) -> None: - config = ChatWidgetConfig( - title="AI Assistant", - chat_config=ChatConfig(model="gpt-4o"), - ) - assert config.chat_config.model == "gpt-4o" - - -# ============================================================================= -# Content Part Tests -# ============================================================================= - - -class TestContentParts: - """Test ACP ContentBlock types.""" - - def test_text_part(self) -> None: - part = TextPart(text="hello") - assert part.type == "text" - assert part.text == "hello" - - def test_text_part_with_annotations(self) -> None: - part = TextPart(text="hello", annotations={"source": "llm"}) - assert part.annotations["source"] == "llm" - - def test_image_part(self) -> None: - part = ImagePart(data="base64data", mimeType="image/png") - assert part.type == "image" - assert part.data == "base64data" - assert part.mime_type == "image/png" - - def test_audio_part(self) -> None: - part = AudioPart(data="audiodata", mimeType="audio/wav") - assert part.type == "audio" - assert part.mime_type == "audio/wav" - - def test_resource_link_part(self) -> None: - part = ResourceLinkPart( - uri="pywry://resource/1", - name="Doc", - title="My Document", - size=1024, - ) - assert part.type == "resource_link" - assert part.name == "Doc" - assert part.title == "My Document" - assert part.size == 1024 - - def test_embedded_resource_part(self) -> None: - part = EmbeddedResourcePart( - resource=EmbeddedResource( - uri="file:///doc.txt", - mimeType="text/plain", - text="Hello world", - ), - ) - assert part.type == "resource" - assert part.resource.text == "Hello world" - - -# ============================================================================= -# GenerationHandle Tests -# ============================================================================= - - -class TestGenerationHandle: - """Test GenerationHandle dataclass.""" - - def test_creation(self) -> None: - handle = GenerationHandle( - message_id="msg_1", - widget_id="w_1", - thread_id="t_1", - ) - assert handle.message_id == "msg_1" - assert not handle.cancel_event.is_set() - - def test_cancel(self) -> None: - handle = GenerationHandle( - message_id="msg_1", - widget_id="w_1", - thread_id="t_1", - ) - handle.cancel() - assert handle.cancel_event.is_set() - - def test_append_chunk(self) -> None: - handle = GenerationHandle( - message_id="msg_1", - widget_id="w_1", - thread_id="t_1", - ) - handle.append_chunk("Hello ") - handle.append_chunk("world") - assert handle.partial_content == "Hello world" - - def test_append_after_cancel_is_noop(self) -> None: - handle = GenerationHandle( - message_id="msg_1", - widget_id="w_1", - thread_id="t_1", - ) - handle.append_chunk("before") - handle.cancel() - handle.append_chunk("after") - assert handle.partial_content == "before" - - def test_is_expired(self) -> None: - handle = GenerationHandle( - message_id="msg_1", - widget_id="w_1", - thread_id="t_1", - ) - assert not handle.is_expired - handle.created_at = time.time() - GENERATION_HANDLE_TTL - 1 - assert handle.is_expired - - # ============================================================================= # ChatStateMixin Tests # ============================================================================= diff --git a/pywry/tests/test_chat_manager.py b/pywry/tests/test_chat_manager.py index 6a154c5..c90678a 100644 --- a/pywry/tests/test_chat_manager.py +++ b/pywry/tests/test_chat_manager.py @@ -16,10 +16,18 @@ - _on_request_state emits full initialization state - _on_settings_change_event updates internal state - _on_slash_command_event handles /clear + delegates to user callback +- Artifact dispatch for every artifact type +- Asset injection for AG Grid / Plotly / TradingView (emit + anywidget paths) +- @-context attachments and auto-attached widget context +- Edit/Resend flows including provider integration """ from __future__ import annotations +import asyncio +import builtins +import pathlib +import threading import time from typing import Any @@ -36,14 +44,17 @@ PlotlyArtifact, TableArtifact, TradingViewArtifact, + TradingViewSeries, ) from pywry.chat.manager import ( Attachment, ChatContext, ChatManager, SettingsItem, + _StreamState, + _tool_result_text, ) -from pywry.chat.session import PlanEntry +from pywry.chat.session import AgentCapabilities, PlanEntry from pywry.chat.updates import ( AgentMessageUpdate, ArtifactUpdate, @@ -56,7 +67,7 @@ # ============================================================================= -# Fixtures +# Module-level fixtures and helpers # ============================================================================= @@ -82,6 +93,36 @@ def clear(self) -> None: self.events.clear() +class FakeWidgetNoEmitFire: + """Widget without ``emit_fire`` — exercises the fallback in ``_emit_fire``.""" + + def __init__(self) -> None: + self.events: list[tuple[str, dict]] = [] + + def emit(self, event_type: str, data: dict[str, Any]) -> None: + self.events.append((event_type, data)) + + def get_events(self, event_type: str) -> list[dict]: + return [d for e, d in self.events if e == event_type] + + +class FakeAnywidget: + """Stand-in for an anywidget-style widget — captures ``set_trait`` calls.""" + + def __init__(self) -> None: + self.traits: dict[str, Any] = {} + self.events: list[tuple[str, dict]] = [] + + def emit(self, event_type: str, data: dict[str, Any]) -> None: + self.events.append((event_type, data)) + + def emit_fire(self, event_type: str, data: dict[str, Any]) -> None: + self.events.append((event_type, data)) + + def set_trait(self, key: str, value: Any) -> None: + self.traits[key] = value + + def echo_handler(messages, ctx): """Simple handler that returns the last user message.""" return f"Echo: {messages[-1]['text']}" @@ -113,6 +154,26 @@ def rich_handler(messages, ctx): yield AgentMessageUpdate(text="Done!") +class _MinimalAsyncProvider: + """Minimal async-iterable provider used as a stand-in for ChatProvider. + + Subclasses override :meth:`prompt` to yield the updates the test needs. + """ + + async def initialize(self, _caps): + return AgentCapabilities() + + async def new_session(self, _cwd, mcp_servers=None): + return "sid" + + async def cancel(self, _sid): + return None + + async def prompt(self, _sid, _content, _cancel_event=None): + if False: + yield # pragma: no cover + + @pytest.fixture def widget(): return FakeWidget() @@ -142,6 +203,43 @@ def bound_manager(widget): return mgr +def _seed_thread(mgr: ChatManager) -> tuple[str, list[dict[str, Any]]]: + """Populate the manager's active thread with a four-message conversation.""" + tid = mgr.active_thread_id + msgs = [ + {"id": "msg_user_1", "role": "user", "text": "first question"}, + {"id": "msg_asst_1", "role": "assistant", "text": "first answer"}, + {"id": "msg_user_2", "role": "user", "text": "second question"}, + {"id": "msg_asst_2", "role": "assistant", "text": "second answer"}, + ] + mgr._threads[tid] = list(msgs) + return tid, msgs + + +# ============================================================================= +# Module helpers +# ============================================================================= + + +class TestToolResultText: + """Test the _tool_result_text content flattener.""" + + def test_string(self): + assert _tool_result_text("hi") == "hi" + + def test_list_of_text_parts(self): + result = _tool_result_text([{"type": "text", "text": "a"}, {"type": "text", "text": "b"}]) + assert result == "ab" + + def test_list_with_strings(self): + result = _tool_result_text(["str", {"type": "text", "text": "x"}]) + assert result == "strx" + + def test_other_returns_empty(self): + assert _tool_result_text(42) == "" + assert _tool_result_text(None) == "" + + # ============================================================================= # Update Type Tests # ============================================================================= @@ -229,8 +327,6 @@ def test_json_artifact(self): assert a.artifact_type == "json" def test_tradingview_artifact(self): - from pywry.chat.artifacts import TradingViewSeries - a = TradingViewArtifact( title="AAPL", series=[TradingViewSeries(type="candlestick", data=[])], @@ -259,18 +355,19 @@ def test_attachment_summary_empty(self): ctx = ChatContext() assert ctx.attachment_summary == "" - def test_attachment_summary_file(self): - import pathlib - + def test_attachment_summary_file_with_path(self): ctx = ChatContext( attachments=[ Attachment(type="file", name="report.csv", path=pathlib.Path("/data/report.csv")), ] ) assert "report.csv" in ctx.attachment_summary - assert "report.csv" in ctx.attachment_summary assert str(pathlib.Path("/data/report.csv")) in ctx.attachment_summary + def test_attachment_summary_file_without_path(self): + ctx = ChatContext(attachments=[Attachment(type="file", name="a.csv")]) + assert "a.csv (file)" in ctx.attachment_summary + def test_attachment_summary_widget(self): ctx = ChatContext( attachments=[ @@ -279,7 +376,10 @@ def test_attachment_summary_widget(self): ) assert "@Sales Data" in ctx.attachment_summary - def test_context_text(self): + def test_context_text_empty(self): + assert ChatContext().context_text == "" + + def test_context_text_widget(self): ctx = ChatContext( attachments=[ Attachment(type="widget", name="@Grid", content="col1,col2\n1,2"), @@ -289,7 +389,17 @@ def test_context_text(self): assert "Grid" in text assert "col1,col2" in text - def test_get_attachment_found(self): + def test_context_text_file_with_path(self): + ctx = ChatContext( + attachments=[ + Attachment(type="file", name="data.csv", path=pathlib.Path("/tmp/data.csv")), + ] + ) + text = ctx.context_text + assert "data.csv" in text + assert "Path:" in text + + def test_get_attachment_widget(self): ctx = ChatContext( attachments=[ Attachment(type="widget", name="@Sales", content="revenue=100"), @@ -298,6 +408,15 @@ def test_get_attachment_found(self): assert ctx.get_attachment("Sales") == "revenue=100" assert ctx.get_attachment("@Sales") == "revenue=100" + def test_get_attachment_file_returns_path_string(self): + ctx = ChatContext( + attachments=[ + Attachment(type="file", name="data.csv", path=pathlib.Path("/tmp/data.csv")), + ] + ) + # File attachments resolve to the path string + assert "data.csv" in ctx.get_attachment("data.csv") + def test_get_attachment_not_found(self): ctx = ChatContext(attachments=[]) result = ctx.get_attachment("Missing") @@ -306,8 +425,20 @@ def test_get_attachment_not_found(self): def test_wait_for_input_cancel(self): ctx = ChatContext() ctx.cancel_event.set() - result = ctx.wait_for_input(timeout=0.1) - assert result == "" + assert ctx.wait_for_input(timeout=0.1) == "" + + def test_wait_for_input_timeout(self): + ctx = ChatContext() + start = time.time() + assert ctx.wait_for_input(timeout=0.05) == "" + assert (time.time() - start) < 1.0 + + def test_wait_for_input_returns_response(self): + ctx = ChatContext() + ctx._input_response = "answer" + ctx._input_event.set() + assert ctx.wait_for_input() == "answer" + assert not ctx._input_event.is_set() # ============================================================================= @@ -337,11 +468,11 @@ def test_range(self): # ============================================================================= -# ChatManager Tests +# Construction / configuration # ============================================================================= -class TestChatManager: +class TestChatManagerInit: """Test ChatManager construction and public API.""" def test_construction(self): @@ -380,6 +511,25 @@ def test_settings_property(self): ) assert mgr.settings["model"] == "gpt-4" + def test_file_attach_without_accept_types_raises(self): + with pytest.raises(ValueError, match="file_accept_types is required"): + ChatManager(handler=echo_handler, enable_file_attach=True) + + def test_file_attach_with_accept_types_ok(self): + mgr = ChatManager( + handler=echo_handler, + enable_file_attach=True, + file_accept_types=[".csv"], + ) + assert mgr._file_accept_types == [".csv"] + + def test_context_allowed_roots_resolved(self, tmp_path): + mgr = ChatManager( + handler=echo_handler, + context_allowed_roots=[str(tmp_path)], + ) + assert mgr._context_allowed_roots == [str(pathlib.Path(tmp_path).resolve())] + def test_send_message(self, bound_manager, widget): bound_manager.send_message("Hello from code") events = widget.get_events("chat:assistant-message") @@ -393,6 +543,58 @@ def test_send_message_stores_in_thread(self, bound_manager): assert bound_manager.threads[tid][0]["text"] == "stored" +class TestBind: + """Test the ``bind()`` method and its lazy anywidget import.""" + + def test_bind_to_fake_widget_sets_widget(self, widget): + mgr = ChatManager(handler=echo_handler) + mgr.bind(widget) + assert mgr._widget is widget + # FakeWidget isn't an anywidget instance + assert mgr._is_anywidget is False + + def test_bind_falls_back_when_anywidget_missing(self, monkeypatch): + import sys + + mgr = ChatManager(handler=echo_handler) + # The lazy ``from ..widget import PyWryChatWidget`` inside bind() must + # raise to hit the fallback. Force the cached pywry.widget module to + # raise on the PyWryChatWidget lookup. + cached = sys.modules.get("pywry.widget") + if cached is not None: + # Create a shim that raises AttributeError → caught as ImportError + # path requires a real ImportError; replace the module with a + # module-like object whose attribute access fails. + class _BrokenWidgetModule: + def __getattr__(self, name): + if name == "PyWryChatWidget": + raise ImportError("PyWryChatWidget missing") + return getattr(cached, name) + + monkeypatch.setitem(sys.modules, "pywry.widget", _BrokenWidgetModule()) + + widget = FakeWidget() + mgr.bind(widget) + assert mgr._widget is widget + assert mgr._is_anywidget is False + + +class TestToolbarMethod: + """Test the ``toolbar()`` factory.""" + + def test_returns_toolbar_instance(self): + from pywry.toolbar import Toolbar + + mgr = ChatManager(handler=echo_handler) + tb = mgr.toolbar() + assert isinstance(tb, Toolbar) + + +# ============================================================================= +# Handler dispatch +# ============================================================================= + + class TestChatManagerHandlerDispatch: """Test handler invocation and stream processing.""" @@ -404,7 +606,6 @@ def test_echo_handler(self, widget): "chat:user-message", "", ) - # Wait for background thread time.sleep(0.3) events = widget.get_events("chat:assistant-message") assert any("Echo: hello" in e.get("text", "") for e in events) @@ -419,7 +620,6 @@ def test_stream_handler(self, widget): ) time.sleep(0.3) chunks = widget.get_events("chat:stream-chunk") - # Should have streaming chunks + done assert len(chunks) > 0 done_chunks = [c for c in chunks if c.get("done")] assert len(done_chunks) >= 1 @@ -450,281 +650,543 @@ def slow_handler(messages, ctx): stopped = [c for c in chunks if c.get("stopped")] assert len(stopped) >= 1 + def test_handler_exception_emits_error_message(self, widget): + def bad_handler(messages, ctx): + raise ValueError("oops") -class TestChatManagerThreads: - """Test thread CRUD operations.""" + mgr = ChatManager(handler=bad_handler) + mgr.bind(widget) + mgr._on_user_message( + {"text": "hi", "threadId": mgr.active_thread_id}, + "chat:user-message", + "", + ) + time.sleep(0.3) + msgs = widget.get_events("chat:assistant-message") + assert any("Error: oops" in m.get("text", "") for m in msgs) - def test_create_thread(self, bound_manager, widget): - bound_manager._on_thread_create({"title": "New Thread"}, "", "") - events = widget.get_events("chat:update-thread-list") - assert len(events) >= 1 - assert len(bound_manager.threads) == 2 + def test_user_message_empty_text_returns(self, widget): + mgr = ChatManager(handler=echo_handler) + mgr.bind(widget) + mgr._on_user_message( + {"text": " ", "threadId": mgr.active_thread_id}, + "chat:user-message", + "", + ) + # No assistant message generated, no typing indicator emitted + assert widget.get_events("chat:assistant-message") == [] + assert widget.get_events("chat:typing-indicator") == [] - def test_switch_thread(self, bound_manager, widget): - bound_manager._on_thread_create({"title": "Thread 2"}, "", "") - new_tid = bound_manager.active_thread_id - old_tid = next(t for t in bound_manager.threads if t != new_tid) - bound_manager._on_thread_switch({"threadId": old_tid}, "", "") - assert bound_manager.active_thread_id == old_tid - def test_delete_thread(self, bound_manager, widget): - bound_manager._on_thread_create({"title": "To Delete"}, "", "") - tid = bound_manager.active_thread_id - bound_manager._on_thread_delete({"threadId": tid}, "", "") - assert tid not in bound_manager.threads +class TestHandlerResultDispatch: + """Cover the four return-value paths for handlers.""" - def test_rename_thread(self, bound_manager, widget): - tid = bound_manager.active_thread_id - bound_manager._on_thread_rename({"threadId": tid, "title": "Renamed"}, "", "") - events = widget.get_events("chat:update-thread-list") - assert len(events) >= 1 + def test_coroutine_returning_string(self, widget): + async def handler(messages, ctx): + return "async-result" + mgr = ChatManager(handler=handler) + mgr.bind(widget) + mgr._on_user_message( + {"text": "hi", "threadId": mgr.active_thread_id}, + "chat:user-message", + "", + ) + time.sleep(0.4) + msgs = widget.get_events("chat:assistant-message") + assert any("async-result" in m.get("text", "") for m in msgs) -class TestChatManagerState: - """Test state management.""" + def test_coroutine_returning_async_generator(self, widget): + async def gen(): + yield "x" + yield "y" - def test_request_state(self, bound_manager, widget): - bound_manager._on_request_state({}, "", "") - events = widget.get_events("chat:state-response") - assert len(events) == 1 - state = events[0] - assert "threads" in state - assert "activeThreadId" in state + async def handler(messages, ctx): + return gen() - def test_request_state_with_welcome(self, widget): - mgr = ChatManager(handler=echo_handler, welcome_message="Welcome!") + mgr = ChatManager(handler=handler) mgr.bind(widget) - mgr._on_request_state({}, "", "") - events = widget.get_events("chat:state-response") - assert len(events) == 1 - messages = events[0]["messages"] - assert any("Welcome!" in m.get("content", "") for m in messages) + mgr._on_user_message( + {"text": "hi", "threadId": mgr.active_thread_id}, + "chat:user-message", + "", + ) + time.sleep(0.4) + chunks = widget.get_events("chat:stream-chunk") + assert chunks - def test_settings_change(self, bound_manager, widget): - callback = MagicMock() - bound_manager._on_settings_change = callback - bound_manager._on_settings_change_event({"key": "model", "value": "gpt-4o"}, "", "") - assert bound_manager.settings["model"] == "gpt-4o" - callback.assert_called_once_with("model", "gpt-4o") + def test_coroutine_returning_other_type(self, widget): + async def handler(messages, ctx): + return 42 - def test_slash_command_clear(self, bound_manager, widget): - tid = bound_manager.active_thread_id - bound_manager.send_message("test") - assert len(bound_manager.threads[tid]) == 1 - bound_manager._on_slash_command_event({"command": "/clear", "threadId": tid}, "", "") - assert len(bound_manager.threads[tid]) == 0 + mgr = ChatManager(handler=handler) + mgr.bind(widget) + mgr._on_user_message( + {"text": "hi", "threadId": mgr.active_thread_id}, + "chat:user-message", + "", + ) + time.sleep(0.4) + msgs = widget.get_events("chat:assistant-message") + assert any("42" in m.get("text", "") for m in msgs) + def test_sync_iterator_handler(self, widget): + def handler(messages, ctx): + return iter(["x", "y"]) -# ============================================================================= -# Edit / Resend Tests -# ============================================================================= + mgr = ChatManager(handler=handler) + mgr.bind(widget) + mgr._on_user_message( + {"text": "hi", "threadId": mgr.active_thread_id}, + "chat:user-message", + "", + ) + time.sleep(0.3) + chunks = widget.get_events("chat:stream-chunk") + assert chunks + def test_async_generator_returned_directly(self, widget): + async def agen(): + yield "x" -def _seed_thread(mgr: ChatManager) -> tuple[str, list[dict[str, Any]]]: - """Populate the manager's active thread with a four-message conversation.""" - tid = mgr.active_thread_id - msgs = [ - {"id": "msg_user_1", "role": "user", "text": "first question"}, - {"id": "msg_asst_1", "role": "assistant", "text": "first answer"}, - {"id": "msg_user_2", "role": "user", "text": "second question"}, - {"id": "msg_asst_2", "role": "assistant", "text": "second answer"}, - ] - mgr._threads[tid] = list(msgs) - return tid, msgs + def handler(messages, ctx): + return agen() + mgr = ChatManager(handler=handler) + mgr.bind(widget) + mgr._on_user_message( + {"text": "hi", "threadId": mgr.active_thread_id}, + "chat:user-message", + "", + ) + time.sleep(0.3) + chunks = widget.get_events("chat:stream-chunk") + assert chunks -class TestTruncateThreadAt: - """Direct unit tests for the _truncate_thread_at helper.""" + def test_other_return_type_stringified(self, widget): + def handler(messages, ctx): + return 3.14 - def test_keep_target_drops_messages_after(self, bound_manager): - tid, _ = _seed_thread(bound_manager) - removed, removed_ids = bound_manager._truncate_thread_at( - tid, "msg_user_2", keep_target=True + mgr = ChatManager(handler=handler) + mgr.bind(widget) + mgr._on_user_message( + {"text": "hi", "threadId": mgr.active_thread_id}, + "chat:user-message", + "", ) - kept_ids = [m["id"] for m in bound_manager._threads[tid]] - assert kept_ids == ["msg_user_1", "msg_asst_1", "msg_user_2"] - assert removed_ids == ["msg_asst_2"] - assert len(removed) == 1 + time.sleep(0.3) + msgs = widget.get_events("chat:assistant-message") + assert any("3.14" in m.get("text", "") for m in msgs) - def test_drop_target_removes_message_and_after(self, bound_manager): - tid, _ = _seed_thread(bound_manager) - removed, removed_ids = bound_manager._truncate_thread_at( - tid, "msg_user_2", keep_target=False - ) - kept_ids = [m["id"] for m in bound_manager._threads[tid]] - assert kept_ids == ["msg_user_1", "msg_asst_1"] - assert removed_ids == ["msg_user_2", "msg_asst_2"] - assert len(removed) == 2 - def test_unknown_message_id_no_op(self, bound_manager): - tid, msgs = _seed_thread(bound_manager) - removed, removed_ids = bound_manager._truncate_thread_at(tid, "ghost", keep_target=True) - assert removed == [] - assert removed_ids == [] - # Thread untouched - assert [m["id"] for m in bound_manager._threads[tid]] == [m["id"] for m in msgs] +class TestStreamCancelPaths: + """Cancellation paths in the sync/async stream handlers.""" + def test_handle_stream_cancel_mid_stream(self, widget): + mgr = ChatManager(handler=echo_handler) + mgr.bind(widget) -class TestEditMessage: - """Tests for _on_edit_message — replace text + truncate + regenerate.""" + cancel = threading.Event() - def test_edit_emits_messages_deleted(self, bound_manager, widget): - tid, _ = _seed_thread(bound_manager) - bound_manager._on_edit_message( - {"messageId": "msg_user_2", "threadId": tid, "text": "REVISED"}, - "chat:edit-message", - "", - ) - time.sleep(0.2) # allow background thread - deletions = widget.get_events("chat:messages-deleted") - assert deletions, "expected at least one chat:messages-deleted event" - d = deletions[0] - assert d["editedMessageId"] == "msg_user_2" - assert d["editedText"] == "REVISED" - # Only the trailing assistant reply should be removed - assert d["messageIds"] == ["msg_asst_2"] + def gen(): + yield "a" + cancel.set() + yield "b" - def test_edit_replaces_user_message_text(self, bound_manager, widget): - tid, _ = _seed_thread(bound_manager) - bound_manager._on_edit_message( - {"messageId": "msg_user_2", "threadId": tid, "text": "REVISED"}, - "chat:edit-message", - "", - ) - time.sleep(0.2) - thread = bound_manager._threads[tid] - # Find the edited user message - edited = next(m for m in thread if m.get("id") == "msg_user_2") - assert edited["text"] == "REVISED" + mgr._handle_stream(gen(), "msg-1", "thread-1", cancel) + chunks = widget.get_events("chat:stream-chunk") + stopped = [c for c in chunks if c.get("stopped")] + assert stopped - def test_edit_unknown_message_is_noop(self, bound_manager, widget): - tid, msgs = _seed_thread(bound_manager) - bound_manager._on_edit_message( - {"messageId": "ghost", "threadId": tid, "text": "x"}, - "chat:edit-message", - "", - ) - # No deletion event, no thread mutation - assert widget.get_events("chat:messages-deleted") == [] - assert [m["id"] for m in bound_manager._threads[tid]] == [m["id"] for m in msgs] + async def test_handle_async_stream_cancel_path(self, widget): + mgr = ChatManager(handler=echo_handler) + mgr.bind(widget) - def test_edit_empty_text_is_noop(self, bound_manager, widget): - tid, _ = _seed_thread(bound_manager) - bound_manager._on_edit_message( - {"messageId": "msg_user_2", "threadId": tid, "text": " "}, - "chat:edit-message", - "", - ) - assert widget.get_events("chat:messages-deleted") == [] + cancel = threading.Event() + async def agen(): + yield "x" + cancel.set() + yield "y" -class TestResendFrom: - """Tests for _on_resend_from — drop target + everything after, regenerate.""" + await mgr._handle_async_stream(agen(), "msg-1", "thread-1", cancel) + chunks = widget.get_events("chat:stream-chunk") + stopped = [c for c in chunks if c.get("stopped")] + assert stopped - def test_resend_keeps_target_and_drops_only_later_messages(self, bound_manager, widget): - tid, _ = _seed_thread(bound_manager) - bound_manager._on_resend_from( - {"messageId": "msg_user_2", "threadId": tid}, - "chat:resend-from", - "", + async def test_handle_async_stream_no_items_still_finalizes(self, widget): + mgr = ChatManager(handler=echo_handler) + mgr.bind(widget) + cancel = threading.Event() + + async def agen(): + if False: + yield # pragma: no cover + + await mgr._handle_async_stream(agen(), "msg-1", "thread-1", cancel) + # The async stream must still emit stream-done even on empty output + chunks = widget.get_events("chat:stream-chunk") + done = [c for c in chunks if c.get("done")] + assert done + + +# ============================================================================= +# Asset injection +# ============================================================================= + + +class TestAssetInjection: + """Test lazy asset injection for AG Grid / Plotly / TradingView.""" + + def test_aggrid_assets_inject_via_emit(self, widget): + mgr = ChatManager(handler=echo_handler) + mgr.bind(widget) + mgr._inject_aggrid_assets() + events = widget.get_events("chat:load-assets") + assert len(events) == 1 + assert events[0]["scripts"] + # Idempotent + mgr._inject_aggrid_assets() + assert len(widget.get_events("chat:load-assets")) == 1 + + def test_plotly_assets_inject_via_emit(self, widget): + mgr = ChatManager(handler=echo_handler) + mgr.bind(widget) + mgr._inject_plotly_assets() + events = widget.get_events("chat:load-assets") + assert events + assert events[0]["scripts"] + + def test_plotly_assets_idempotent_when_include_plotly_true(self, widget): + # include_plotly=True marks _plotly_assets_sent=True at init — + # no load-assets event when we call inject afterward. + mgr = ChatManager(handler=echo_handler, include_plotly=True) + mgr.bind(widget) + mgr._inject_plotly_assets() + assert widget.get_events("chat:load-assets") == [] + + def test_tradingview_assets_inject_via_emit(self, widget): + mgr = ChatManager(handler=echo_handler) + mgr.bind(widget) + mgr._inject_tradingview_assets() + assert widget.get_events("chat:load-assets") + # Idempotent + widget.clear() + mgr._inject_tradingview_assets() + assert widget.get_events("chat:load-assets") == [] + + def test_anywidget_aggrid_uses_set_trait(self): + mgr = ChatManager(handler=echo_handler) + w = FakeAnywidget() + mgr.bind(w) + mgr._is_anywidget = True + mgr._inject_aggrid_assets() + assert "_asset_js" in w.traits + assert "_asset_css" in w.traits + + def test_anywidget_plotly_uses_set_trait(self): + mgr = ChatManager(handler=echo_handler) + w = FakeAnywidget() + mgr.bind(w) + mgr._is_anywidget = True + mgr._inject_plotly_assets() + assert "_asset_js" in w.traits + + def test_anywidget_tradingview_uses_set_trait(self): + mgr = ChatManager(handler=echo_handler) + w = FakeAnywidget() + mgr.bind(w) + mgr._is_anywidget = True + mgr._inject_tradingview_assets() + assert "_asset_js" in w.traits + + +# ============================================================================= +# Artifact dispatch +# ============================================================================= + + +class TestArtifactDispatch: + """Test the _dispatch_artifact method for every artifact type.""" + + def _setup(self): + mgr = ChatManager(handler=echo_handler) + widget = FakeWidget() + mgr.bind(widget) + return mgr, widget + + def test_dispatch_code(self): + mgr, widget = self._setup() + artifact = CodeArtifact(title="t", content="x = 1", language="python") + mgr._dispatch_artifact(artifact, "msg-1", "thread-1") + events = widget.get_events("chat:artifact") + assert events[0]["content"] == "x = 1" + assert events[0]["language"] == "python" + + def test_dispatch_markdown(self): + mgr, widget = self._setup() + mgr._dispatch_artifact(MarkdownArtifact(title="t", content="# Hi"), "m", "t") + assert widget.get_events("chat:artifact")[0]["content"] == "# Hi" + + def test_dispatch_html(self): + mgr, widget = self._setup() + mgr._dispatch_artifact(HtmlArtifact(title="t", content="x"), "m", "t") + assert "x" in widget.get_events("chat:artifact")[0]["content"] + + def test_dispatch_table(self): + mgr, widget = self._setup() + artifact = TableArtifact(title="t", data=[{"a": 1, "b": 2}]) + mgr._dispatch_artifact(artifact, "m", "t") + events = widget.get_events("chat:artifact") + assert events + assert events[0]["rowData"] + assert events[0]["columns"] + + def test_dispatch_table_with_column_defs_and_options(self): + mgr, widget = self._setup() + artifact = TableArtifact( + title="t", + data=[{"a": 1}], + column_defs=[{"field": "a"}], + grid_options={"rowHeight": 50}, ) - time.sleep(0.2) - deletions = widget.get_events("chat:messages-deleted") - assert deletions - d = deletions[0] - # The target user message stays — only the assistant reply (and - # any subsequent turns) are dropped so "Resend" doesn't read as - # "your message was erased". - assert d["messageIds"] == ["msg_asst_2"] - # No edited-message flags — the user message text is unchanged, - # so the frontend doesn't need to re-render its content. - assert "editedMessageId" not in d - assert "editedText" not in d - # Server-side thread keeps the target user message in place. - surviving_ids = [m["id"] for m in bound_manager._threads[tid]] - assert "msg_user_2" in surviving_ids - assert "msg_asst_2" not in surviving_ids + mgr._dispatch_artifact(artifact, "m", "t") + events = widget.get_events("chat:artifact") + assert events[0]["columnDefs"] == [{"field": "a"}] + assert events[0]["gridOptions"] == {"rowHeight": 50} + + def test_dispatch_plotly(self): + mgr, widget = self._setup() + artifact = PlotlyArtifact(title="t", figure={"data": [], "layout": {}}) + mgr._dispatch_artifact(artifact, "m", "t") + events = widget.get_events("chat:artifact") + assert events + assert events[0]["figure"]["data"] == [] + + def test_dispatch_tradingview(self): + mgr, widget = self._setup() + artifact = TradingViewArtifact( + title="t", + series=[TradingViewSeries(type="candlestick", data=[])], + options={"timezone": "UTC"}, + ) + mgr._dispatch_artifact(artifact, "m", "t") + events = widget.get_events("chat:artifact") + assert events + assert events[0]["options"] == {"timezone": "UTC"} + assert events[0]["series"][0]["type"] == "candlestick" + + def test_dispatch_image(self): + mgr, widget = self._setup() + mgr._dispatch_artifact( + ImageArtifact(title="t", url="https://example.com/x.png", alt="x"), "m", "t" + ) + assert widget.get_events("chat:artifact")[0]["url"] == "https://example.com/x.png" - def test_resend_re_runs_handler_with_same_text(self, bound_manager, widget): - tid, _ = _seed_thread(bound_manager) - bound_manager._on_resend_from( - {"messageId": "msg_user_2", "threadId": tid}, - "chat:resend-from", - "", + def test_dispatch_json(self): + mgr, widget = self._setup() + mgr._dispatch_artifact(JsonArtifact(title="t", data={"k": "v"}), "m", "t") + assert widget.get_events("chat:artifact")[0]["data"] == {"k": "v"} + + +# ============================================================================= +# Session-update dispatcher +# ============================================================================= + + +class TestSessionUpdateDispatch: + """_dispatch_session_update routes each update type to the right event.""" + + def _setup(self): + mgr = ChatManager(handler=echo_handler) + widget = FakeWidget() + mgr.bind(widget) + return mgr, widget, _StreamState("m") + + def test_dispatch_status(self): + mgr, widget, state = self._setup() + mgr._dispatch_session_update(StatusUpdate(text="working"), state, "t", None) + events = widget.get_events("chat:status-update") + assert events and events[0]["text"] == "working" + + def test_dispatch_thinking(self): + mgr, widget, state = self._setup() + mgr._dispatch_session_update(ThinkingUpdate(text="hmm"), state, "t", None) + events = widget.get_events("chat:thinking-chunk") + assert events and events[0]["text"] == "hmm" + + def test_dispatch_citation(self): + mgr, widget, state = self._setup() + mgr._dispatch_session_update( + CitationUpdate(url="https://x", title="T", snippet="s"), state, "t", None ) - time.sleep(0.3) - # Echo handler returns "Echo: " — verify the assistant reply - # came back for the resent prompt. - replies = widget.get_events("chat:assistant-message") - assert any("Echo: second question" in r.get("text", "") for r in replies) + events = widget.get_events("chat:citation") + assert events and events[0]["url"] == "https://x" + + def test_dispatch_tool_call_in_progress(self): + mgr, widget, state = self._setup() + mgr._dispatch_session_update( + ToolCallUpdate(toolCallId="t1", name="x", kind="other", status="in_progress"), + state, + "t", + None, + ) + events = widget.get_events("chat:tool-call") + assert events and events[0]["name"] == "x" + + def test_dispatch_tool_call_completed_emits_result(self): + mgr, widget, state = self._setup() + mgr._dispatch_session_update( + ToolCallUpdate( + toolCallId="t1", + name="x", + kind="other", + status="completed", + content=[{"type": "text", "text": "result"}], + ), + state, + "t", + None, + ) + events = widget.get_events("chat:tool-result") + assert events + assert events[0]["result"] == "result" + + def test_dispatch_artifact_update(self): + mgr, widget, state = self._setup() + mgr._dispatch_session_update( + ArtifactUpdate(artifact=CodeArtifact(title="x", content="x = 1", language="python")), + state, + "t", + None, + ) + assert widget.get_events("chat:artifact") - def test_resend_unknown_message_is_noop(self, bound_manager, widget): - tid, msgs = _seed_thread(bound_manager) - bound_manager._on_resend_from( - {"messageId": "ghost", "threadId": tid}, - "chat:resend-from", - "", + def test_process_handler_item_artifact_passes_through(self): + mgr, widget, state = self._setup() + # ArtifactBase instances dispatched directly (not wrapped in ArtifactUpdate) + mgr._process_handler_item(MarkdownArtifact(title="x", content="# Hi"), state, "t", None) + assert widget.get_events("chat:artifact") + + def test_process_handler_item_string_buffers(self): + mgr, widget, state = self._setup() + mgr._process_handler_item("hello", state, "t", None) + # STREAM_FLUSH_INTERVAL=0 → flushed immediately + assert widget.get_events("chat:stream-chunk") + + def test_process_handler_item_plain_object_silent(self): + mgr, widget, state = self._setup() + # Non-string, non-artifact, no .session_update attr — silently dispatches + # to _dispatch_session_update which has no branch for it. + mgr._process_handler_item(object(), state, "t", None) + # Nothing visible came out + assert widget.get_events("chat:stream-chunk") == [] + + +# ============================================================================= +# _is_accepted_file +# ============================================================================= + + +class TestIsAcceptedFile: + def test_no_filter_accepts_all(self): + mgr = ChatManager(handler=echo_handler) + assert mgr._is_accepted_file("anything.zip") is True + + def test_extension_accepted(self): + mgr = ChatManager( + handler=echo_handler, + enable_file_attach=True, + file_accept_types=[".csv", ".json"], ) - assert widget.get_events("chat:messages-deleted") == [] - # Original thread is intact - assert [m["id"] for m in bound_manager._threads[tid]] == [m["id"] for m in msgs] + assert mgr._is_accepted_file("data.csv") is True + assert mgr._is_accepted_file("schema.json") is True - def test_resend_targeting_assistant_message_is_noop(self, bound_manager, widget): - """Only user messages can be resent; assistant ids are ignored.""" - tid, msgs = _seed_thread(bound_manager) - bound_manager._on_resend_from( - {"messageId": "msg_asst_1", "threadId": tid}, - "chat:resend-from", - "", + def test_extension_rejected(self): + mgr = ChatManager( + handler=echo_handler, + enable_file_attach=True, + file_accept_types=[".csv"], ) - assert widget.get_events("chat:messages-deleted") == [] - assert [m["id"] for m in bound_manager._threads[tid]] == [m["id"] for m in msgs] + assert mgr._is_accepted_file("malware.exe") is False -class TestUserMessageStoresId: - """The frontend-generated messageId must round-trip into thread storage.""" +# ============================================================================= +# _emit_fire fallback +# ============================================================================= - def test_user_message_uses_provided_id(self, bound_manager): - tid = bound_manager.active_thread_id - bound_manager._on_user_message( - {"messageId": "msg_provided_42", "text": "hi", "threadId": tid}, - "chat:user-message", - "", + +class TestEmitFireFallback: + def test_widget_without_emit_fire_falls_back_to_emit(self): + mgr = ChatManager(handler=echo_handler) + widget = FakeWidgetNoEmitFire() + mgr.bind(widget) + mgr._emit_fire("chat:test", {"x": 1}) + assert widget.get_events("chat:test") + + +# ============================================================================= +# Attachment resolution +# ============================================================================= + + +class TestResolveAttachments: + """Test the public _resolve_attachments method.""" + + def test_file_with_path_resolved(self, tmp_path): + mgr = ChatManager( + handler=echo_handler, + enable_file_attach=True, + file_accept_types=[".csv"], ) - time.sleep(0.2) - first = bound_manager._threads[tid][0] - assert first["id"] == "msg_provided_42" - assert first["role"] == "user" + f = tmp_path / "x.csv" + f.write_text("a,b\n1,2") + result = mgr._resolve_attachments([{"type": "file", "name": "x.csv", "path": str(f)}]) + assert len(result) == 1 + assert result[0].type == "file" + assert result[0].path == f + + def test_file_with_only_content(self): + mgr = ChatManager( + handler=echo_handler, + enable_file_attach=True, + file_accept_types=[".csv"], + ) + result = mgr._resolve_attachments([{"type": "file", "name": "x.csv", "content": "data"}]) + assert len(result) == 1 + assert result[0].content == "data" - def test_user_message_generates_id_if_absent(self, bound_manager): - tid = bound_manager.active_thread_id - bound_manager._on_user_message( - {"text": "hi", "threadId": tid}, - "chat:user-message", - "", + def test_file_rejected_extension(self): + mgr = ChatManager( + handler=echo_handler, + enable_file_attach=True, + file_accept_types=[".csv"], ) - time.sleep(0.2) - first = bound_manager._threads[tid][0] - assert first["id"].startswith("msg_") + result = mgr._resolve_attachments([{"type": "file", "name": "evil.exe", "content": "x"}]) + assert result == [] - def test_assistant_message_carries_id(self, bound_manager): - tid = bound_manager.active_thread_id - bound_manager._on_user_message( - {"text": "hi", "threadId": tid}, - "chat:user-message", - "", + def test_file_no_path_no_content_skipped(self): + mgr = ChatManager( + handler=echo_handler, + enable_file_attach=True, + file_accept_types=[".csv"], ) - time.sleep(0.3) - # Echo handler completes synchronously; assistant message should be in thread - msgs = bound_manager._threads[tid] - asst = [m for m in msgs if m.get("role") == "assistant"] - assert asst - assert asst[0]["id"].startswith("msg_") + result = mgr._resolve_attachments([{"type": "file", "name": "x.csv"}]) + assert result == [] + + def test_widget_without_id_skipped(self): + mgr = ChatManager(handler=echo_handler, enable_context=True) + result = mgr._resolve_attachments([{"type": "widget"}]) + assert result == [] + + def test_no_context_or_file_attach_returns_empty(self): + mgr = ChatManager(handler=echo_handler) + result = mgr._resolve_attachments([{"type": "file", "name": "x.csv"}]) + assert result == [] -class TestWidgetAttachmentCarriesWidgetId: - """Every @-mention attachment must surface the widget_id explicitly so an - LLM agent reading the @-context can use it directly in MCP tool calls.""" +class TestResolveWidgetAttachment: + """Test the _resolve_widget_attachment helper across binding states.""" def _make_mgr(self) -> ChatManager: m = ChatManager(handler=echo_handler, enable_context=True) @@ -738,10 +1200,8 @@ def test_registered_source_with_getdata_content_carries_widget_id(self): "chart", content="symbol: AAPL\ninterval: 1d", name="chart" ) assert att is not None - # The attachment content always starts with widget_id: first_line = att.content.splitlines()[0] assert first_line == "widget_id: chart" - # The original getData payload is preserved after the header assert "symbol: AAPL" in att.content assert att.source == "chart" assert att.name == "@chart" @@ -751,14 +1211,11 @@ def test_registered_source_without_getdata_still_carries_widget_id(self): mgr.register_context_source("chart", "chart") att = mgr._resolve_widget_attachment("chart") assert att is not None - # Even with no JS-side getData payload, widget_id is in the content assert "widget_id: chart" in att.content assert att.source == "chart" def test_unregistered_widget_id_still_yields_attachment_with_id(self): mgr = self._make_mgr() - # No register_context_source call — but the user mentioned an unknown - # widgetId from the frontend. The attachment must still surface it. att = mgr._resolve_widget_attachment("some-other-widget") assert att is not None assert "widget_id: some-other-widget" in att.content @@ -781,9 +1238,141 @@ def test_resolve_attachments_dispatches_to_widget_helper(self): att = attachments[0] assert att.type == "widget" assert att.source == "chart" - # widget_id header is the bridge between @-context and MCP tool calls assert att.content.splitlines()[0] == "widget_id: chart" + def test_no_widget_returns_attachment_with_widget_id(self): + mgr = ChatManager(handler=echo_handler, enable_context=True) + # No widget bound + att = mgr._resolve_widget_attachment("missing-id") + assert att is not None + assert "widget_id: missing-id" in att.content + + def test_widget_without_app_attribute_returns_id_only(self): + mgr = ChatManager(handler=echo_handler, enable_context=True) + mgr.bind(FakeWidget()) # Has no ._app + att = mgr._resolve_widget_attachment("missing-widget") + assert att is not None + assert "widget_id: missing-widget" in att.content + + def test_widget_with_app_no_inline_widgets(self): + mgr = ChatManager(handler=echo_handler, enable_context=True) + + class _Widget: + def __init__(self): + self._app = MagicMock() + self._app._inline_widgets = {} + + def emit(self, *_a, **_k): + pass + + mgr.bind(_Widget()) + att = mgr._resolve_widget_attachment("absent") + assert att is not None + assert "widget_id: absent" in att.content + + def test_widget_with_inline_widget_renders_html_size(self): + mgr = ChatManager(handler=echo_handler, enable_context=True) + + class _InlineWidget: + label = "MyChart" + html = "

chart

" + + class _App: + def __init__(self): + self._inline_widgets = {"chart-id": _InlineWidget()} + + class _Widget: + def __init__(self): + self._app = _App() + + def emit(self, *_a, **_k): + pass + + mgr.bind(_Widget()) + att = mgr._resolve_widget_attachment("chart-id") + assert att is not None + assert "widget_id: chart-id" in att.content + assert "MyChart" in att.content + assert "HTML widget" in att.content + + def test_resolve_widget_attachment_handles_exception(self): + mgr = ChatManager(handler=echo_handler, enable_context=True) + + class _BadWidget: + @property + def _app(self): + raise RuntimeError("boom") + + def emit(self, *_a, **_k): + pass + + mgr.bind(_BadWidget()) + att = mgr._resolve_widget_attachment("foo") + assert att is not None + assert "widget_id: foo" in att.content + + +class TestGetContextSources: + """Test the context-source enumeration used by the @-mention popup.""" + + def test_returns_registered_sources(self): + mgr = ChatManager(handler=echo_handler, enable_context=True) + mgr.register_context_source("a", "Alpha") + mgr.register_context_source("b", "Beta") + sources = mgr._get_context_sources() + ids = {s["id"] for s in sources} + assert ids == {"a", "b"} + + def test_includes_inline_widgets_not_in_registry(self): + mgr = ChatManager(handler=echo_handler, enable_context=True) + mgr.register_context_source("registered", "Reg") + + class _Inline: + label = "InlineWidget" + + class _App: + def __init__(self): + self._inline_widgets = { + "registered": _Inline(), + "auto": _Inline(), + } + + class _Widget: + def __init__(self): + self._app = _App() + + def emit(self, *_a, **_k): + pass + + mgr.bind(_Widget()) + sources = mgr._get_context_sources() + ids = {s["id"] for s in sources} + assert "auto" in ids + # Only one entry per id + assert len([s for s in sources if s["id"] == "registered"]) == 1 + + def test_handles_app_lookup_error(self): + mgr = ChatManager(handler=echo_handler, enable_context=True) + + class _BadApp: + @property + def _inline_widgets(self): + raise RuntimeError("boom") + + class _Widget: + _app = _BadApp() + + def emit(self, *_a, **_k): + pass + + mgr.bind(_Widget()) + assert mgr._get_context_sources() == [] + + +# ============================================================================= +# Auto-attached registered context sources +# ============================================================================= + class TestRegisteredContextAutoAttaches: """Every registered context source rides along on every user message @@ -799,13 +1388,10 @@ def test_no_auto_attach_when_context_disabled(self): m = ChatManager(handler=echo_handler, enable_context=False) m.bind(FakeWidget()) m.register_context_source("chart", "chart") - merged = m._auto_attach_context_sources([]) - assert merged == [] + assert m._auto_attach_context_sources([]) == [] def test_no_auto_attach_when_no_sources_registered(self): - mgr = self._make_mgr() - merged = mgr._auto_attach_context_sources([]) - assert merged == [] + assert self._make_mgr()._auto_attach_context_sources([]) == [] def test_registered_source_is_auto_attached(self): mgr = self._make_mgr() @@ -828,7 +1414,6 @@ def test_explicit_mention_takes_precedence(self): auto_attached=False, ) merged = mgr._auto_attach_context_sources([explicit]) - # Only the explicit mention survives — no duplicate assert len(merged) == 1 assert merged[0] is explicit assert merged[0].auto_attached is False @@ -853,7 +1438,6 @@ def handler(messages, ctx): time.sleep(0.3) assert captured, "handler was never invoked" attachments = captured[0] - # Exactly one auto-attached widget — the chart assert len(attachments) == 1 att = attachments[0] assert att.source == "chart" @@ -872,7 +1456,6 @@ def test_inject_context_skips_ui_card_for_auto_attachments(self, widget): "", ) time.sleep(0.3) - # No attach_widget tool-call cards from the auto-attach attach_cards = [ d for e, d in widget.events @@ -901,6 +1484,587 @@ def handler(messages, ctx): time.sleep(0.3) assert captured injected = captured[0] - # Both the @-context block and the original user text assert "widget_id: chart" in injected assert "switch to MSFT" in injected + + +class TestInjectContext: + """Direct tests for _inject_context.""" + + def test_explicit_file_attachment_emits_tool_card(self, tmp_path): + mgr = ChatManager( + handler=echo_handler, + enable_file_attach=True, + file_accept_types=[".csv"], + ) + widget = FakeWidget() + mgr.bind(widget) + + ctx = ChatContext( + attachments=[ + Attachment( + type="file", + name="data.csv", + path=tmp_path / "data.csv", + auto_attached=False, + ), + ] + ) + messages = [{"role": "user", "text": "use it"}] + result = mgr._inject_context(messages, ctx, "m", "t") + # Tool-call + tool-result cards emitted + assert widget.get_events("chat:tool-call") + assert widget.get_events("chat:tool-result") + # Last user message was prefixed with the context block + last_user = next(m for m in result if m["role"] == "user") + assert "data.csv" in last_user["text"] + + def test_widget_attachment_no_path_uses_attach_widget_card(self): + mgr = ChatManager(handler=echo_handler, enable_context=True) + widget = FakeWidget() + mgr.bind(widget) + ctx = ChatContext( + attachments=[ + Attachment( + type="widget", + name="@chart", + content="widget_id: chart\nsymbol: BTC", + auto_attached=False, + ), + ] + ) + messages = [{"role": "user", "text": "what is this?"}] + mgr._inject_context(messages, ctx, "m", "t") + cards = widget.get_events("chat:tool-call") + assert any(c.get("name") == "attach_widget" for c in cards) + results = widget.get_events("chat:tool-result") + assert any("Attached chart" in r.get("result", "") for r in results) + + def test_no_user_message_passes_through(self): + mgr = ChatManager(handler=echo_handler, enable_context=True) + mgr.bind(FakeWidget()) + ctx = ChatContext(attachments=[Attachment(type="widget", name="@x", content="ctx")]) + messages = [{"role": "system", "text": "sys"}] + # No user message — system passes through unchanged + assert mgr._inject_context(messages, ctx, "m", "t") == messages + + +# ============================================================================= +# Thread CRUD +# ============================================================================= + + +class TestChatManagerThreads: + """Test thread CRUD operations.""" + + def test_create_thread(self, bound_manager, widget): + bound_manager._on_thread_create({"title": "New Thread"}, "", "") + events = widget.get_events("chat:update-thread-list") + assert len(events) >= 1 + assert len(bound_manager.threads) == 2 + + def test_switch_thread(self, bound_manager, widget): + bound_manager._on_thread_create({"title": "Thread 2"}, "", "") + new_tid = bound_manager.active_thread_id + old_tid = next(t for t in bound_manager.threads if t != new_tid) + bound_manager._on_thread_switch({"threadId": old_tid}, "", "") + assert bound_manager.active_thread_id == old_tid + + def test_delete_thread(self, bound_manager, widget): + bound_manager._on_thread_create({"title": "To Delete"}, "", "") + tid = bound_manager.active_thread_id + bound_manager._on_thread_delete({"threadId": tid}, "", "") + assert tid not in bound_manager.threads + + def test_rename_thread(self, bound_manager, widget): + tid = bound_manager.active_thread_id + bound_manager._on_thread_rename({"threadId": tid, "title": "Renamed"}, "", "") + events = widget.get_events("chat:update-thread-list") + assert len(events) >= 1 + + def test_switch_unknown_thread_no_op(self, bound_manager): + bound_manager._on_thread_switch({"threadId": "nope"}, "chat:thread-switch", "") + assert bound_manager.active_thread_id != "nope" + + def test_switch_emits_existing_messages(self, bound_manager, widget): + # Add a thread with messages + bound_manager._on_thread_create({"title": "T2"}, "chat:thread-create", "") + new_tid = bound_manager.active_thread_id + bound_manager._threads[new_tid] = [ + {"id": "1", "role": "user", "text": "u"}, + {"id": "2", "role": "assistant", "text": "a"}, + ] + widget.clear() + bound_manager._on_thread_switch({"threadId": new_tid}, "chat:thread-switch", "") + msgs = widget.get_events("chat:assistant-message") + assert len(msgs) == 2 + assert any(m.get("role") == "user" for m in msgs) + + +# ============================================================================= +# State management / request-state +# ============================================================================= + + +class TestChatManagerState: + """Test the request-state and settings-change event handlers.""" + + def test_request_state(self, bound_manager, widget): + bound_manager._on_request_state({}, "", "") + events = widget.get_events("chat:state-response") + assert len(events) == 1 + state = events[0] + assert "threads" in state + assert "activeThreadId" in state + + def test_request_state_with_welcome(self, widget): + mgr = ChatManager(handler=echo_handler, welcome_message="Welcome!") + mgr.bind(widget) + mgr._on_request_state({}, "", "") + events = widget.get_events("chat:state-response") + assert len(events) == 1 + messages = events[0]["messages"] + assert any("Welcome!" in m.get("content", "") for m in messages) + + def test_settings_change(self, bound_manager, widget): + callback = MagicMock() + bound_manager._on_settings_change = callback + bound_manager._on_settings_change_event({"key": "model", "value": "gpt-4o"}, "", "") + assert bound_manager.settings["model"] == "gpt-4o" + callback.assert_called_once_with("model", "gpt-4o") + + def test_clear_history_setting_clears_active_thread(self, bound_manager, widget): + tid = bound_manager.active_thread_id + bound_manager._threads[tid] = [{"id": "1", "role": "user", "text": "x"}] + bound_manager._on_settings_change_event( + {"key": "clear-history", "value": True}, + "chat:settings-change", + "", + ) + assert bound_manager._threads[tid] == [] + assert widget.get_events("chat:clear") + + def test_slash_command_clear(self, bound_manager, widget): + tid = bound_manager.active_thread_id + bound_manager.send_message("test") + assert len(bound_manager.threads[tid]) == 1 + bound_manager._on_slash_command_event({"command": "/clear", "threadId": tid}, "", "") + assert len(bound_manager.threads[tid]) == 0 + + +class TestRequestStateExtra: + """Test slash-command / settings-item / context-source emission paths.""" + + def test_slash_command_models_emitted(self, widget): + from pywry.chat.models import ACPCommand + + cmds = [ACPCommand(name="web", description="search the web")] + mgr = ChatManager(handler=echo_handler, slash_commands=cmds) + mgr.bind(widget) + mgr._on_request_state({}, "chat:request-state", "") + cmd_events = widget.get_events("chat:register-command") + names = [e["name"] for e in cmd_events] + assert "web" in names + assert "/clear" in names + + def test_settings_items_registered(self, widget): + mgr = ChatManager( + handler=echo_handler, + settings=[SettingsItem(id="model", label="Model", type="select", options=["a", "b"])], + ) + mgr.bind(widget) + mgr._on_request_state({}, "chat:request-state", "") + settings_events = widget.get_events("chat:register-settings-item") + assert len(settings_events) == 1 + assert settings_events[0]["id"] == "model" + + def test_context_sources_emitted_when_enabled(self, widget): + mgr = ChatManager(handler=echo_handler, enable_context=True) + mgr.register_context_source("c1", "Chart") + mgr.bind(widget) + mgr._on_request_state({}, "chat:request-state", "") + events = widget.get_events("chat:context-sources") + assert events + assert events[0]["sources"][0]["id"] == "c1" + + +class TestSmallEventHandlers: + """todo-clear / input-response / slash-command delegation.""" + + def test_todo_clear_emits_empty_items(self, widget): + mgr = ChatManager(handler=echo_handler) + mgr.bind(widget) + mgr._on_todo_clear({}, "chat:todo-clear", "") + assert widget.get_events("chat:todo-update") == [{"items": []}] + + def test_input_response_no_pending_request_no_op(self): + mgr = ChatManager(handler=echo_handler) + # No pending input — silently dropped + mgr._on_input_response( + {"text": "hi", "requestId": "req-x", "threadId": mgr.active_thread_id}, + "chat:input-response", + "", + ) + + def test_input_response_resolves_pending(self): + mgr = ChatManager(handler=echo_handler) + ctx = ChatContext() + mgr._pending_inputs["req-1"] = {"ctx": ctx} + mgr._on_input_response( + {"text": "answer", "requestId": "req-1", "threadId": mgr.active_thread_id}, + "chat:input-response", + "", + ) + assert ctx._input_response == "answer" + assert ctx._input_event.is_set() + assert mgr._threads[mgr.active_thread_id] + + def test_slash_command_delegates_to_user_handler(self): + seen: list = [] + + def on_slash(name, args, tid): + seen.append((name, args, tid)) + + mgr = ChatManager(handler=echo_handler, on_slash_command=on_slash) + mgr.bind(FakeWidget()) + mgr._on_slash_command_event( + {"command": "/foo", "args": "x y", "threadId": mgr.active_thread_id}, + "chat:slash-command", + "", + ) + assert seen == [("/foo", "x y", mgr.active_thread_id)] + + +# ============================================================================= +# Edit / Resend Tests +# ============================================================================= + + +class TestTruncateThreadAt: + """Direct unit tests for the _truncate_thread_at helper.""" + + def test_keep_target_drops_messages_after(self, bound_manager): + tid, _ = _seed_thread(bound_manager) + removed, removed_ids = bound_manager._truncate_thread_at( + tid, "msg_user_2", keep_target=True + ) + kept_ids = [m["id"] for m in bound_manager._threads[tid]] + assert kept_ids == ["msg_user_1", "msg_asst_1", "msg_user_2"] + assert removed_ids == ["msg_asst_2"] + assert len(removed) == 1 + + def test_drop_target_removes_message_and_after(self, bound_manager): + tid, _ = _seed_thread(bound_manager) + removed, removed_ids = bound_manager._truncate_thread_at( + tid, "msg_user_2", keep_target=False + ) + kept_ids = [m["id"] for m in bound_manager._threads[tid]] + assert kept_ids == ["msg_user_1", "msg_asst_1"] + assert removed_ids == ["msg_user_2", "msg_asst_2"] + assert len(removed) == 2 + + def test_unknown_message_id_no_op(self, bound_manager): + tid, msgs = _seed_thread(bound_manager) + removed, removed_ids = bound_manager._truncate_thread_at(tid, "ghost", keep_target=True) + assert removed == [] + assert removed_ids == [] + assert [m["id"] for m in bound_manager._threads[tid]] == [m["id"] for m in msgs] + + +class TestEditMessage: + """Tests for _on_edit_message — replace text + truncate + regenerate.""" + + def test_edit_emits_messages_deleted(self, bound_manager, widget): + tid, _ = _seed_thread(bound_manager) + bound_manager._on_edit_message( + {"messageId": "msg_user_2", "threadId": tid, "text": "REVISED"}, + "chat:edit-message", + "", + ) + time.sleep(0.2) + deletions = widget.get_events("chat:messages-deleted") + assert deletions, "expected at least one chat:messages-deleted event" + d = deletions[0] + assert d["editedMessageId"] == "msg_user_2" + assert d["editedText"] == "REVISED" + assert d["messageIds"] == ["msg_asst_2"] + + def test_edit_replaces_user_message_text(self, bound_manager, widget): + tid, _ = _seed_thread(bound_manager) + bound_manager._on_edit_message( + {"messageId": "msg_user_2", "threadId": tid, "text": "REVISED"}, + "chat:edit-message", + "", + ) + time.sleep(0.2) + thread = bound_manager._threads[tid] + edited = next(m for m in thread if m.get("id") == "msg_user_2") + assert edited["text"] == "REVISED" + + def test_edit_unknown_message_is_noop(self, bound_manager, widget): + tid, msgs = _seed_thread(bound_manager) + bound_manager._on_edit_message( + {"messageId": "ghost", "threadId": tid, "text": "x"}, + "chat:edit-message", + "", + ) + assert widget.get_events("chat:messages-deleted") == [] + assert [m["id"] for m in bound_manager._threads[tid]] == [m["id"] for m in msgs] + + def test_edit_empty_text_is_noop(self, bound_manager, widget): + tid, _ = _seed_thread(bound_manager) + bound_manager._on_edit_message( + {"messageId": "msg_user_2", "threadId": tid, "text": " "}, + "chat:edit-message", + "", + ) + assert widget.get_events("chat:messages-deleted") == [] + + def test_edit_cancels_active_generation(self, widget): + mgr = ChatManager(handler=echo_handler) + mgr.bind(widget) + tid = mgr.active_thread_id + cancel = threading.Event() + mgr._cancel_events[tid] = cancel + mgr._threads[tid] = [ + {"id": "msg_user_1", "role": "user", "text": "first"}, + {"id": "msg_asst_1", "role": "assistant", "text": "reply"}, + ] + mgr._on_edit_message( + {"messageId": "msg_user_1", "threadId": tid, "text": "REVISED"}, + "chat:edit-message", + "", + ) + time.sleep(0.2) + assert cancel.is_set() + + +class TestResendFrom: + """Tests for _on_resend_from — drop target + everything after, regenerate.""" + + def test_resend_keeps_target_and_drops_only_later_messages(self, bound_manager, widget): + tid, _ = _seed_thread(bound_manager) + bound_manager._on_resend_from( + {"messageId": "msg_user_2", "threadId": tid}, + "chat:resend-from", + "", + ) + time.sleep(0.2) + deletions = widget.get_events("chat:messages-deleted") + assert deletions + d = deletions[0] + # The target user message stays — only the assistant reply (and any + # subsequent turns) are dropped so "Resend" doesn't read as "your + # message was erased". + assert d["messageIds"] == ["msg_asst_2"] + assert "editedMessageId" not in d + assert "editedText" not in d + surviving_ids = [m["id"] for m in bound_manager._threads[tid]] + assert "msg_user_2" in surviving_ids + assert "msg_asst_2" not in surviving_ids + + def test_resend_re_runs_handler_with_same_text(self, bound_manager, widget): + tid, _ = _seed_thread(bound_manager) + bound_manager._on_resend_from( + {"messageId": "msg_user_2", "threadId": tid}, + "chat:resend-from", + "", + ) + time.sleep(0.3) + replies = widget.get_events("chat:assistant-message") + assert any("Echo: second question" in r.get("text", "") for r in replies) + + def test_resend_unknown_message_is_noop(self, bound_manager, widget): + tid, msgs = _seed_thread(bound_manager) + bound_manager._on_resend_from( + {"messageId": "ghost", "threadId": tid}, + "chat:resend-from", + "", + ) + assert widget.get_events("chat:messages-deleted") == [] + assert [m["id"] for m in bound_manager._threads[tid]] == [m["id"] for m in msgs] + + def test_resend_empty_message_id_returns(self, widget): + mgr = ChatManager(handler=echo_handler) + mgr.bind(widget) + mgr._on_resend_from( + {"messageId": "", "threadId": mgr.active_thread_id}, + "chat:resend-from", + "", + ) + assert widget.get_events("chat:messages-deleted") == [] + + def test_resend_targeting_assistant_message_is_noop(self, bound_manager, widget): + """Only user messages can be resent; assistant ids are ignored.""" + tid, msgs = _seed_thread(bound_manager) + bound_manager._on_resend_from( + {"messageId": "msg_asst_1", "threadId": tid}, + "chat:resend-from", + "", + ) + assert widget.get_events("chat:messages-deleted") == [] + assert [m["id"] for m in bound_manager._threads[tid]] == [m["id"] for m in msgs] + + def test_resend_cancels_active_generation(self, widget): + mgr = ChatManager(handler=echo_handler) + mgr.bind(widget) + tid = mgr.active_thread_id + cancel = threading.Event() + mgr._cancel_events[tid] = cancel + mgr._threads[tid] = [ + {"id": "msg_user_1", "role": "user", "text": "first"}, + {"id": "msg_asst_1", "role": "assistant", "text": "reply"}, + ] + mgr._on_resend_from( + {"messageId": "msg_user_1", "threadId": tid}, + "chat:resend-from", + "", + ) + time.sleep(0.2) + assert cancel.is_set() + + +class TestUserMessageStoresId: + """The frontend-generated messageId must round-trip into thread storage.""" + + def test_user_message_uses_provided_id(self, bound_manager): + tid = bound_manager.active_thread_id + bound_manager._on_user_message( + {"messageId": "msg_provided_42", "text": "hi", "threadId": tid}, + "chat:user-message", + "", + ) + time.sleep(0.2) + first = bound_manager._threads[tid][0] + assert first["id"] == "msg_provided_42" + assert first["role"] == "user" + + def test_user_message_generates_id_if_absent(self, bound_manager): + tid = bound_manager.active_thread_id + bound_manager._on_user_message( + {"text": "hi", "threadId": tid}, + "chat:user-message", + "", + ) + time.sleep(0.2) + first = bound_manager._threads[tid][0] + assert first["id"].startswith("msg_") + + def test_assistant_message_carries_id(self, bound_manager): + tid = bound_manager.active_thread_id + bound_manager._on_user_message( + {"text": "hi", "threadId": tid}, + "chat:user-message", + "", + ) + time.sleep(0.3) + msgs = bound_manager._threads[tid] + asst = [m for m in msgs if m.get("role") == "assistant"] + assert asst + assert asst[0]["id"].startswith("msg_") + + +# ============================================================================= +# Provider integration +# ============================================================================= + + +class TestProviderRun: + """Test the ACP-provider execution path through _run_provider.""" + + def test_provider_exception_emits_error_message(self, widget): + class _Provider(_MinimalAsyncProvider): + async def prompt(self, _sid, _content, _cancel_event=None): + raise RuntimeError("provider boom") + if False: + yield # pragma: no cover + + mgr = ChatManager(provider=_Provider()) + mgr.bind(widget) + mgr._on_user_message( + {"text": "hi", "threadId": mgr.active_thread_id}, + "chat:user-message", + "", + ) + time.sleep(0.5) + msgs = widget.get_events("chat:assistant-message") + assert any("Error: provider boom" in m.get("text", "") for m in msgs) + + async def test_provider_no_updates_still_clears_typing_indicator(self, widget): + """When the provider yields zero updates, the after-loop typing-off + emit still runs.""" + + class _Provider(_MinimalAsyncProvider): + async def prompt(self, _sid, _content, _cancel_event=None): + # Yield no updates — typing_hidden stays False inside _run_provider + if False: + yield # pragma: no cover + + mgr = ChatManager(provider=_Provider()) + mgr.bind(widget) + mgr._on_user_message( + {"text": "hi", "threadId": mgr.active_thread_id}, + "chat:user-message", + "", + ) + time.sleep(0.4) + events = widget.get_events("chat:typing-indicator") + offs = [e for e in events if e.get("typing") is False] + assert offs + + +class TestTruncateProviderState: + """_truncate_provider_state forwards to provider.truncate_session.""" + + def test_no_provider_no_op(self): + mgr = ChatManager(handler=echo_handler) + mgr._truncate_provider_state("t", []) # silent + + def test_calls_provider_truncate_session(self): + called: list = [] + + class _Provider(_MinimalAsyncProvider): + def truncate_session(self, sid, kept): + called.append((sid, kept)) + + mgr = ChatManager(provider=_Provider()) + mgr._truncate_provider_state("thread-1", [{"id": "x"}]) + assert called == [("thread-1", [{"id": "x"}])] + + def test_provider_without_truncate_session_no_op(self): + mgr = ChatManager(provider=_MinimalAsyncProvider()) + mgr._truncate_provider_state("t", []) + + def test_provider_truncate_exception_swallowed(self): + class _Provider(_MinimalAsyncProvider): + def truncate_session(self, _sid, _kept): + raise RuntimeError("boom") + + mgr = ChatManager(provider=_Provider()) + mgr._truncate_provider_state("t", []) + + +class TestResendWithProvider: + def test_resend_with_provider_runs_provider_path(self, widget): + """The resend dispatch picks the provider path when ``_provider`` is set.""" + + class _Provider(_MinimalAsyncProvider): + async def prompt(self, _sid, _content, _cancel_event=None): + yield AgentMessageUpdate(text="provider-reply") + + mgr = ChatManager(provider=_Provider()) + mgr.bind(widget) + tid = mgr.active_thread_id + mgr._threads[tid] = [ + {"id": "msg_user_1", "role": "user", "text": "first"}, + {"id": "msg_asst_1", "role": "assistant", "text": "reply"}, + ] + mgr._on_resend_from( + {"messageId": "msg_user_1", "threadId": tid}, + "chat:resend-from", + "", + ) + time.sleep(0.5) + assert any( + "provider-reply" in c.get("chunk", "") for c in widget.get_events("chat:stream-chunk") + ) diff --git a/pywry/tests/test_cli.py b/pywry/tests/test_cli.py index 308f9b7..dbee57f 100644 --- a/pywry/tests/test_cli.py +++ b/pywry/tests/test_cli.py @@ -400,3 +400,427 @@ def test_shows_precedence_note(self): output = mock_stdout.getvalue() assert "override" in output.lower() or "precedence" in output.lower() + + +class TestHandlePluginPath: + """Tests for the plugin-path command.""" + + def test_returns_plugin_root(self, tmp_path, monkeypatch): + # Mock pywry.__file__ and Path.exists to simulate plugin installed. + import pywry + from pywry.cli import handle_plugin_path + + # Build a fake plugin directory tree + fake_pkg = tmp_path / "pywry" + plugin_root = fake_pkg / "_claude_plugin" / ".claude-plugin" + plugin_root.mkdir(parents=True) + (plugin_root / "marketplace.json").write_text("{}") + (plugin_root / "plugin.json").write_text("{}") + + # Patch pywry.__file__ so resolve points to fake_pkg/__init__.py + fake_init = fake_pkg / "__init__.py" + fake_init.write_text("") + monkeypatch.setattr(pywry, "__file__", str(fake_init)) + + args = argparse.Namespace(check=False, marketplace=False) + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + result = handle_plugin_path(args) + assert result == 0 + assert "_claude_plugin" in mock_stdout.getvalue() + + def test_marketplace_flag_returns_marketplace_path(self, tmp_path, monkeypatch): + import pywry + from pywry.cli import handle_plugin_path + + fake_pkg = tmp_path / "pywry" + plugin_root = fake_pkg / "_claude_plugin" / ".claude-plugin" + plugin_root.mkdir(parents=True) + (plugin_root / "marketplace.json").write_text("{}") + (plugin_root / "plugin.json").write_text("{}") + fake_init = fake_pkg / "__init__.py" + fake_init.write_text("") + monkeypatch.setattr(pywry, "__file__", str(fake_init)) + + args = argparse.Namespace(check=False, marketplace=True) + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + result = handle_plugin_path(args) + assert result == 0 + assert "marketplace.json" in mock_stdout.getvalue() + + def test_check_missing_returns_error(self, tmp_path, monkeypatch): + import pywry + from pywry.cli import handle_plugin_path + + fake_pkg = tmp_path / "pywry" + fake_pkg.mkdir(parents=True) + fake_init = fake_pkg / "__init__.py" + fake_init.write_text("") + monkeypatch.setattr(pywry, "__file__", str(fake_init)) + + args = argparse.Namespace(check=True, marketplace=False) + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + result = handle_plugin_path(args) + assert result == 1 + assert "not found" in mock_stderr.getvalue().lower() + + +class TestHandleConfigOutputFile: + """Tests for the --output option of config command.""" + + def test_writes_to_file(self, tmp_path): + out = tmp_path / "out.toml" + args = argparse.Namespace( + show=False, toml=True, env=False, sources=False, output=str(out) + ) + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + result = handle_config(args) + assert result == 0 + assert out.exists() + assert "written" in mock_stdout.getvalue().lower() + + +class TestHandleConfigEnvOption: + def test_env_output_to_stdout(self): + args = argparse.Namespace( + show=False, toml=False, env=True, sources=False, output=None + ) + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + result = handle_config(args) + assert result == 0 + out = mock_stdout.getvalue() + # to_env output should contain PYWRY_-prefixed vars or be empty if no overrides + assert isinstance(out, str) + + +class TestHandleMcp: + def test_import_error_path(self): + import builtins + + from pywry.cli import handle_mcp + + real_import = builtins.__import__ + + # Save and remove cached pywry.mcp so the inner re-import goes through __import__ + cached_pkg = sys.modules.pop("pywry.mcp", None) + + def fake_import(name, *a, **k): + # Relative import inside pywry.cli: `from .mcp import run_server` + # passes name="mcp" (or "pywry.mcp" for absolute). + if name == "mcp" or name == "pywry.mcp": + raise ImportError("no mcp") + return real_import(name, *a, **k) + + args = argparse.Namespace( + transport="stdio", + port=None, + host=None, + name=None, + headless=False, + native=False, + ) + try: + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + with patch.object(builtins, "__import__", side_effect=fake_import): + result = handle_mcp(args) + assert result == 1 + assert "MCP SDK" in mock_stderr.getvalue() or "mcp" in mock_stderr.getvalue().lower() + finally: + if cached_pkg is not None: + sys.modules["pywry.mcp"] = cached_pkg + + def test_runs_server_successfully(self): + from pywry.cli import handle_mcp + + with patch("pywry.mcp.run_server") as mock_run: + mock_run.return_value = None + args = argparse.Namespace( + transport=None, + port=None, + host=None, + name=None, + headless=False, + native=False, + ) + result = handle_mcp(args) + assert result == 0 + mock_run.assert_called_once() + + def test_keyboard_interrupt(self): + from pywry.cli import handle_mcp + + with patch("pywry.mcp.run_server", side_effect=KeyboardInterrupt): + args = argparse.Namespace( + transport=None, + port=None, + host=None, + name=None, + headless=False, + native=False, + ) + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + result = handle_mcp(args) + assert result == 0 + assert "stopped" in mock_stdout.getvalue().lower() + + def test_server_error(self): + from pywry.cli import handle_mcp + + with patch("pywry.mcp.run_server", side_effect=RuntimeError("boom")): + args = argparse.Namespace( + transport=None, + port=None, + host=None, + name=None, + headless=False, + native=False, + ) + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + result = handle_mcp(args) + assert result == 1 + assert "boom" in mock_stderr.getvalue() or "Error" in mock_stderr.getvalue() + + def test_native_flag_disables_headless(self): + from pywry.cli import handle_mcp + + with patch("pywry.mcp.run_server") as mock_run: + args = argparse.Namespace( + transport=None, + port=None, + host=None, + name="my-server", + headless=False, + native=True, + ) + handle_mcp(args) + # native=True → headless=False + call_kwargs = mock_run.call_args.kwargs + assert call_kwargs.get("headless") is False + + def test_headless_flag(self): + from pywry.cli import handle_mcp + + with patch("pywry.mcp.run_server") as mock_run: + args = argparse.Namespace( + transport=None, + port=None, + host=None, + name=None, + headless=True, + native=False, + ) + handle_mcp(args) + call_kwargs = mock_run.call_args.kwargs + assert call_kwargs.get("headless") is True + + def test_no_flags_uses_config_headless(self, monkeypatch): + """When --headless, --native, and PYWRY_HEADLESS env are all unset, use config.""" + from pywry.cli import handle_mcp + + monkeypatch.delenv("PYWRY_HEADLESS", raising=False) + with patch("pywry.mcp.run_server") as mock_run: + args = argparse.Namespace( + transport=None, + port=None, + host=None, + name=None, + headless=False, + native=False, + ) + handle_mcp(args) + # config default headless is used + assert "headless" in mock_run.call_args.kwargs + + +class TestHandleConfigDefaults: + def test_show_defaults_when_no_flags(self): + """When no toml/env/show flag passed, output defaults to show format.""" + args = argparse.Namespace( + show=False, toml=False, env=False, sources=False, output=None + ) + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + result = handle_config(args) + assert result == 0 + # Default falls through to show format + assert "[csp]" in mock_stdout.getvalue() + + +class TestShowConfigSourcesEnvNotSet: + def test_empty_env_vars(self, monkeypatch): + """Env vars section shows ✗ No vars when no PYWRY_* present.""" + # Remove all PYWRY_ env vars + for key in list(__import__("os").environ.keys()): + if key.startswith("PYWRY_"): + monkeypatch.delenv(key, raising=False) + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + show_config_sources() + out = mock_stdout.getvalue() + assert "No vars" in out or "No" in out + + def test_more_than_three_env_vars_truncates(self, monkeypatch): + """Env vars section shows ... when more than 3 PYWRY_ vars exist.""" + for i in range(5): + monkeypatch.setenv(f"PYWRY_FAKE_{i}", "1") + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + show_config_sources() + out = mock_stdout.getvalue() + assert "..." in out + + +class TestHandleInstallSkillsImportError: + def test_import_error_path(self): + """When the mcp.install module isn't importable, return 1.""" + import builtins + + from pywry.cli import handle_install_skills + + real_import = builtins.__import__ + + # Save and remove cached module so the re-import goes through __import__. + cached = sys.modules.pop("pywry.mcp.install", None) + + def fake_import(name, *a, **k): + # Relative import inside pywry.cli passes name="mcp.install". + if "mcp.install" in name: + raise ImportError("no install") + return real_import(name, *a, **k) + + args = argparse.Namespace( + list=False, + list_targets=False, + target=None, + skills=None, + custom_dir=None, + overwrite=False, + dry_run=False, + verbose=False, + ) + try: + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + with patch.object(builtins, "__import__", side_effect=fake_import): + result = handle_install_skills(args) + assert result == 1 + assert "MCP module" in mock_stderr.getvalue() + finally: + if cached is not None: + sys.modules["pywry.mcp.install"] = cached + + +class TestHandleInstallSkills: + def test_list_flag_prints_skills(self): + from pywry.cli import handle_install_skills + + args = argparse.Namespace( + list=True, + list_targets=False, + target=None, + skills=None, + custom_dir=None, + overwrite=False, + dry_run=False, + verbose=False, + ) + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + result = handle_install_skills(args) + assert result == 0 + assert "skills" in mock_stdout.getvalue().lower() + + def test_list_targets_flag(self): + from pywry.cli import handle_install_skills + + args = argparse.Namespace( + list=False, + list_targets=True, + target=None, + skills=None, + custom_dir=None, + overwrite=False, + dry_run=False, + verbose=False, + ) + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + result = handle_install_skills(args) + assert result == 0 + assert "target" in mock_stdout.getvalue().lower() + + def test_install_with_targets(self): + from pywry.cli import handle_install_skills + + args = argparse.Namespace( + list=False, + list_targets=False, + target=["claude_code"], + skills=None, + custom_dir=None, + overwrite=False, + dry_run=True, + verbose=False, + ) + with patch("pywry.mcp.install.install_skills") as mock_install: + mock_install.return_value = {} + with patch("pywry.mcp.install.print_install_results"): + with patch("sys.stdout", new_callable=StringIO) as mock_stdout: + result = handle_install_skills(args) + assert result == 0 + assert "DRY RUN" in mock_stdout.getvalue() + + def test_install_value_error(self): + from pywry.cli import handle_install_skills + + args = argparse.Namespace( + list=False, + list_targets=False, + target=["invalid"], + skills=None, + custom_dir=None, + overwrite=False, + dry_run=False, + verbose=False, + ) + with patch("pywry.mcp.install.install_skills", side_effect=ValueError("bad")): + with patch("sys.stderr", new_callable=StringIO) as mock_stderr: + result = handle_install_skills(args) + assert result == 1 + assert "Error" in mock_stderr.getvalue() + + +class TestMainEntryPointDispatch: + """Cover the if/elif branches in main() for mcp/install-skills/plugin-path.""" + + def test_mcp_command_dispatches(self): + with patch("pywry.cli.handle_mcp", return_value=0) as mock_handle: + with patch.object(sys, "argv", ["pywry", "mcp"]): + result = main() + assert result == 0 + mock_handle.assert_called_once() + + def test_install_skills_command_dispatches(self): + with patch( + "pywry.cli.handle_install_skills", return_value=0 + ) as mock_handle: + with patch.object(sys, "argv", ["pywry", "install-skills", "--list"]): + result = main() + assert result == 0 + mock_handle.assert_called_once() + + def test_plugin_path_command_dispatches(self): + with patch( + "pywry.cli.handle_plugin_path", return_value=0 + ) as mock_handle: + with patch.object(sys, "argv", ["pywry", "plugin-path"]): + result = main() + assert result == 0 + mock_handle.assert_called_once() + + +class TestMainModuleEntry: + """Cover the `if __name__ == "__main__"` guard.""" + + def test_module_main(self): + import runpy + + with patch("pywry.cli.main", return_value=0): + with patch.object(sys, "argv", ["pywry"]): + try: + runpy.run_module("pywry.cli", run_name="__main__") + except SystemExit as e: + assert e.code == 0 diff --git a/pywry/tests/test_config.py b/pywry/tests/test_config.py index a9c95e0..5474182 100644 --- a/pywry/tests/test_config.py +++ b/pywry/tests/test_config.py @@ -4,6 +4,8 @@ Including CSP (Content Security Policy) configuration and meta tag generation. """ +from unittest.mock import patch + import pytest from pywry.config import ( @@ -656,3 +658,348 @@ def test_env_var_override(self, monkeypatch): def test_available_plugins_has_19(self): """Registry contains exactly 19 plugins.""" assert len(AVAILABLE_TAURI_PLUGINS) == 19 + + +# ───────────────────────────────────────────────────────────────────────────── +# Coverage gaps: helper functions, validators, edge paths +# ───────────────────────────────────────────────────────────────────────────── + + +class TestFindConfigFiles: + """Cover the _find_config_files helper.""" + + def test_finds_pyproject_toml(self, tmp_path, monkeypatch): + from pywry.config import _find_config_files + + monkeypatch.chdir(tmp_path) + (tmp_path / "pyproject.toml").write_text("[tool.pywry]\n") + result = _find_config_files() + assert any("pyproject.toml" in str(p) for p in result) + + def test_finds_pywry_toml(self, tmp_path, monkeypatch): + from pywry.config import _find_config_files + + monkeypatch.chdir(tmp_path) + (tmp_path / "pywry.toml").write_text("") + result = _find_config_files() + assert any("pywry.toml" in str(p) for p in result) + + def test_finds_env_config_file(self, tmp_path, monkeypatch): + from pywry.config import _find_config_files + + monkeypatch.chdir(tmp_path) + env_path = tmp_path / "custom.toml" + env_path.write_text("") + monkeypatch.setenv("PYWRY_CONFIG_FILE", str(env_path)) + result = _find_config_files() + assert any("custom.toml" in str(p) for p in result) + + def test_skips_missing_env_config(self, tmp_path, monkeypatch): + from pywry.config import _find_config_files + + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("PYWRY_CONFIG_FILE", str(tmp_path / "nonexistent.toml")) + result = _find_config_files() + # No file added + assert not any("nonexistent.toml" in str(p) for p in result) + + def test_linux_user_config_path(self, tmp_path, monkeypatch): + """Trigger line 55: linux user config path branch.""" + import pywry.config as cfg + + monkeypatch.chdir(tmp_path) + with patch.object(cfg.sys, "platform", "linux"): + result = cfg._find_config_files() + # Even on linux path, the function still runs and returns a list + assert isinstance(result, list) + + def test_finds_user_config_file(self, tmp_path, monkeypatch): + """Trigger line 58: user_config exists branch (Windows APPDATA path).""" + import pywry.config as cfg + + # Create a fake APPDATA dir with the config file + appdata = tmp_path / "appdata" + config_dir = appdata / "pywry" + config_dir.mkdir(parents=True) + config_file = config_dir / "config.toml" + config_file.write_text("[csp]\n") + + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("APPDATA", str(appdata)) + + # Patch sys.platform to win32 to ensure that branch runs + with patch.object(cfg.sys, "platform", "win32"): + result = cfg._find_config_files() + assert any("config.toml" in str(p) for p in result) + + +class TestLoadTomlConfig: + """Cover _load_toml_config error paths.""" + + def test_handles_decode_error(self, tmp_path, monkeypatch): + from pywry.config import _load_toml_config + + monkeypatch.chdir(tmp_path) + (tmp_path / "pyproject.toml").write_text("[invalid toml [[[") + result = _load_toml_config() + assert isinstance(result, dict) + + def test_returns_empty_when_tomllib_none(self, monkeypatch): + import pywry.config as cfg + + with patch.object(cfg, "tomllib", None): + result = cfg._load_toml_config() + assert result == {} + + +class TestDeepMergeOverride: + """Cover the non-dict override branch of _deep_merge.""" + + def test_override_replaces_non_dict(self): + from pywry.config import _deep_merge + + # base[key] is a non-dict, override[key] is a non-dict → replace + result = _deep_merge({"x": 1, "y": [1, 2]}, {"x": 2, "y": [3, 4]}) + assert result == {"x": 2, "y": [3, 4]} + + def test_override_replaces_dict_with_value(self): + from pywry.config import _deep_merge + + result = _deep_merge({"x": {"a": 1}}, {"x": "literal"}) + assert result == {"x": "literal"} + + def test_recursive_merge_of_dicts(self): + from pywry.config import _deep_merge + + # Triggers line 100 (recursive call) - both keys are dicts + result = _deep_merge( + {"section": {"a": 1, "b": 2}}, + {"section": {"b": 3, "c": 4}}, + ) + assert result == {"section": {"a": 1, "b": 3, "c": 4}} + + +class TestSecuritySettingsLocalhostWithPorts: + def test_with_specific_ports(self): + from pywry.config import SecuritySettings + + settings = SecuritySettings.localhost(ports=[8080, 9000]) + assert "8080" in settings.connect_src + assert "9000" in settings.connect_src + + +class TestTVChartStorageValidators: + """Cover storage identifier/path validation paths.""" + + def test_empty_string_returns_empty(self): + from pywry.config import TVChartSettings + + s = TVChartSettings(storage_namespace="") + assert s.storage_namespace == "" + + def test_too_long_raises(self): + from pywry.config import TVChartSettings + + with pytest.raises(Exception, match="<= 512"): + TVChartSettings(storage_namespace="x" * 513) + + def test_control_chars_raises(self): + from pywry.config import TVChartSettings + + with pytest.raises(Exception, match="control characters"): + TVChartSettings(storage_namespace="bad\x01char") + + def test_unsupported_chars_raises(self): + from pywry.config import TVChartSettings + + with pytest.raises(Exception, match="unsupported"): + TVChartSettings(storage_namespace="bad@char") + + def test_path_too_long_raises(self): + from pywry.config import TVChartSettings + + with pytest.raises(Exception, match="<= 512"): + TVChartSettings(storage_path="x" * 513) + + def test_path_control_chars_raises(self): + from pywry.config import TVChartSettings + + with pytest.raises(Exception, match="control"): + TVChartSettings(storage_path="bad\x01") + + def test_path_unsupported_chars_raises(self): + from pywry.config import TVChartSettings + + with pytest.raises(Exception, match="unsupported"): + TVChartSettings(storage_path="bad@char") + + def test_storage_path_empty_returns_empty(self): + from pywry.config import TVChartSettings + + # Empty string triggers the early-return branch (line 522) + s = TVChartSettings(storage_path="") + assert s.storage_path == "" + + +class TestCommaSeparatedValidators: + def test_watch_directories_string(self): + from pywry.config import HotReloadSettings + + s = HotReloadSettings(watch_directories="a, b, c") + assert s.watch_directories == ["a", "b", "c"] + + def test_watch_directories_empty_string(self): + from pywry.config import HotReloadSettings + + s = HotReloadSettings(watch_directories="") + assert s.watch_directories == [] + + def test_default_roles_string(self): + from pywry.config import DeploySettings + + s = DeploySettings(default_roles="admin, user") + assert "admin" in s.default_roles + + def test_admin_users_string(self): + from pywry.config import DeploySettings + + s = DeploySettings(admin_users="alice, bob") + assert "alice" in s.admin_users + + def test_admin_users_none(self): + from pywry.config import DeploySettings + + s = DeploySettings(admin_users=None) + assert s.admin_users == [] + + def test_public_paths_string(self): + from pywry.config import DeploySettings + + s = DeploySettings(auth_public_paths="/health, /metrics") + assert "/health" in s.auth_public_paths + + def test_server_cors_origins_string(self): + """ServerSettings parses cors_origins from comma-separated string.""" + from pywry.config import ServerSettings + + s = ServerSettings(cors_origins="https://example.com, https://other.com") + assert "https://example.com" in s.cors_origins + + def test_server_cors_methods_string(self): + from pywry.config import ServerSettings + + s = ServerSettings(cors_allow_methods="GET, POST") + assert "GET" in s.cors_allow_methods + + +class TestOAuth2ValidateCustomProvider: + def test_returns_value(self): + from pywry.config import OAuth2Settings + + s = OAuth2Settings(client_id="abc") + assert s.client_id == "abc" + + +class TestTauriPluginsTypeErrors: + def test_invalid_tauri_plugins_type(self): + with pytest.raises(TypeError, match="must be a list"): + PyWrySettings(tauri_plugins=42) + + def test_invalid_extra_capabilities_type(self): + with pytest.raises(TypeError, match="must be a list"): + PyWrySettings(extra_capabilities=42) + + +class TestOAuth2AutoDetection: + def test_auto_oauth2_from_env(self, monkeypatch): + monkeypatch.setenv("PYWRY_OAUTH2__CLIENT_ID", "test-client") + from pywry.config import PyWrySettings + + settings = PyWrySettings() + assert settings.oauth2 is not None + + +class TestToTomlOAuth2Section: + def test_oauth2_included_when_set(self): + from pywry.config import OAuth2Settings, PyWrySettings + + oauth2 = OAuth2Settings(client_id="x", provider="google") + settings = PyWrySettings(oauth2=oauth2) + toml = settings.to_toml() + assert "[oauth2]" in toml + + def test_extra_capabilities_included(self): + settings = PyWrySettings(extra_capabilities=["read-file"]) + toml = settings.to_toml() + assert "extra_capabilities" in toml + + def test_extra_capabilities_in_env(self): + """Trigger line 1338: extra_capabilities export in to_env().""" + settings = PyWrySettings(extra_capabilities=["read-file"]) + env = settings.to_env() + assert "PYWRY_EXTRA_CAPABILITIES" in env + assert "read-file" in env + + +class TestToEnvOAuth2Section: + def test_oauth2_env_export(self): + from pywry.config import OAuth2Settings, PyWrySettings + + oauth2 = OAuth2Settings(client_id="x", provider="google") + settings = PyWrySettings(oauth2=oauth2) + env = settings.to_env() + # OAuth2 may not appear in env-export, but should run without error + assert isinstance(env, str) + assert "PYWRY_" in env + + +class TestSettingsCacheManagement: + def test_get_settings_cached(self): + from pywry.config import clear_settings, get_settings + + clear_settings() + a = get_settings() + b = get_settings() + assert a is b + + def test_clear_settings_resets(self): + from pywry.config import clear_settings, get_settings + + a = get_settings() + clear_settings() + b = get_settings() + # New instance after clear + assert isinstance(b, type(a)) + + def test_reload_settings(self): + from pywry.config import reload_settings + + s = reload_settings() + assert s is not None + + +class TestInvalidSecuritySettings: + """Tests for invalid SecuritySettings values.""" + + def test_empty_default_src_is_allowed(self) -> None: + """Empty default_src is allowed (permissive).""" + # Empty string is technically valid, just insecure + settings = SecuritySettings(default_src="") + assert settings.default_src == "" + + def test_none_default_src_uses_default(self) -> None: + """None default_src falls back to default value.""" + settings = SecuritySettings() + assert settings.default_src != "" + + +class TestInvalidAssetSettings: + """Tests for invalid AssetSettings values.""" + + def test_invalid_plotly_version_format(self) -> None: + """Invalid plotly version format is accepted (no validation).""" + # Version strings aren't validated, user's responsibility + settings = AssetSettings(plotly_version="not-a-version") + assert settings.plotly_version == "not-a-version" + + diff --git a/pywry/tests/test_deepagent_provider.py b/pywry/tests/test_deepagent_provider.py index 732f6b8..9505805 100644 --- a/pywry/tests/test_deepagent_provider.py +++ b/pywry/tests/test_deepagent_provider.py @@ -1,46 +1,141 @@ -"""Tests for the DeepAgentProvider. +"""Tests for ``pywry.chat.providers.deepagent``. -Uses a mock CompiledGraph that yields known astream_events -to verify the provider maps LangGraph events to ACP SessionUpdate types. +Uses a fake CompiledGraph that yields scripted ``astream_events`` to verify +the provider maps LangGraph events to ACP SessionUpdate types, plus direct +tests for the module-level helpers (text filter, todo extraction, inline +tool-call rewriter, etc.) and the truncate-session behaviour. """ from __future__ import annotations import asyncio +import builtins +import sys + +from typing import Any +from unittest.mock import MagicMock import pytest from pywry.chat.models import TextPart -from pywry.chat.providers.deepagent import DeepagentProvider, _map_tool_kind +from pywry.chat.providers import deepagent as da +from pywry.chat.providers.deepagent import ( + DeepagentProvider, + _coerce_text, + _coerce_todo_list, + _consume_one_inline_call, + _extract_answer_from_content, + _extract_stream_text, + _extract_thinking_from_chunk, + _extract_todos_from_tool_output, + _flatten_message_content, + _is_root_chain_end, + _map_todo_status, + _map_tool_kind, + _next_pending_plan_step, + _parse_inline_tool_calls, + _rewrite_inline_tool_call_message, + _rewrite_response_messages, + _scan_balanced_braces, + _stream_part_text, + _strip_special_tokens, + _ToolCallTextFilter, + _try_parse_call_args, +) from pywry.chat.session import ClientCapabilities from pywry.chat.updates import ( AgentMessageUpdate, PlanUpdate, StatusUpdate, + ThinkingUpdate, ToolCallUpdate, ) -class FakeChunk: - def __init__(self, content: str = ""): - self.content = content +# ============================================================================= +# Module-level fixtures / helpers +# ============================================================================= -def make_event(event: str, name: str = "", data: dict | None = None, run_id: str = "r1"): - return {"event": event, "name": name, "data": data or {}, "run_id": run_id} +class FakeChunk: + def __init__(self, content: str = "", additional_kwargs: dict | None = None): + self.content = content + self.additional_kwargs = additional_kwargs or {} -async def fake_stream_events(events: list[dict]): - for e in events: - yield e +def make_event(event: str, name: str = "", data: dict | None = None, run_id: str = "r1", **extra): + return { + "event": event, + "name": name, + "data": data or {}, + "run_id": run_id, + **extra, + } class FakeAgent: + """Yields a scripted list of events from ``astream_events``.""" + def __init__(self, events: list[dict]): self._events = events - def astream_events(self, input_data: dict, config: dict, version: str = "v2"): - return fake_stream_events(self._events) + def astream_events(self, _input_data: dict, config: dict = None, version: str = "v2"): + async def _gen(): + for e in self._events: + yield e + + return _gen() + + +async def _drain_prompt(provider: DeepagentProvider, sid: str, text: str = "hi"): + out = [] + async for u in provider.prompt(sid, [TextPart(text=text)]): + out.append(u) + return out + + +class AIMessage: + """Duck-typed AIMessage. + + The middleware checks ``msg.__class__.__name__ == 'AIMessage'`` and the + rewriter reconstructs the message by calling the class as a constructor, + so this needs to accept the same kwargs as LangChain's AIMessage. + """ + + def __init__( + self, + content=None, + tool_calls=None, + id=None, + response_metadata=None, + additional_kwargs=None, + ): + self.content = content + self.tool_calls = tool_calls + self.id = id or "msg-id" + self.response_metadata = response_metadata or {} + self.additional_kwargs = additional_kwargs or {} + + +class _NotAIMessage: + """Any class with a name other than ``AIMessage`` exits the rewriter early.""" + + +@pytest.fixture +def provider_factory(): + """Build a DeepagentProvider with autosetup disabled and an optional agent.""" + + def _make(agent=None, **kwargs): + defaults = {"auto_checkpointer": False, "auto_store": False} + defaults.update(kwargs) + return DeepagentProvider(agent=agent, **defaults) + + return _make + + +# ============================================================================= +# Tool-kind / todo-status mapping +# ============================================================================= class TestToolKindMapping: @@ -60,10 +155,566 @@ def test_unknown_tool(self): assert _map_tool_kind("my_custom_tool") == "other" +class TestMapTodoStatus: + def test_known_statuses(self): + assert _map_todo_status("todo") == "pending" + assert _map_todo_status("in_progress") == "in_progress" + assert _map_todo_status("in-progress") == "in_progress" + assert _map_todo_status("done") == "completed" + assert _map_todo_status("completed") == "completed" + + def test_unknown_falls_back_to_pending(self): + assert _map_todo_status("foo") == "pending" + + +# ============================================================================= +# _coerce_text and stream-text helpers +# ============================================================================= + + +class TestCoerceText: + def test_none(self): + assert _coerce_text(None) == "" + + def test_string(self): + assert _coerce_text("hello") == "hello" + + def test_list_of_strings(self): + assert _coerce_text(["a", "b"]) == "ab" + + def test_list_of_dicts_text_key(self): + assert _coerce_text([{"text": "a"}, {"text": "b"}]) == "ab" + + def test_list_of_dicts_content_fallback(self): + assert _coerce_text([{"content": "x"}]) == "x" + + def test_list_skips_non_string_text(self): + assert _coerce_text([{"text": 42}]) == "" + + def test_other_type_str(self): + assert _coerce_text(123) == "123" + + +class TestExtractAnswerFromContent: + def test_string_content(self): + assert _extract_answer_from_content("hello") == "hello" + + def test_list_skips_thinking_parts(self): + content = [ + {"type": "thinking", "text": "skip-me"}, + {"type": "text", "text": "keep-me"}, + ] + assert _extract_answer_from_content(content) == "keep-me" + + def test_list_skips_tool_call_parts(self): + content = [ + {"type": "tool_use", "text": "skip-me"}, + {"type": "text", "text": "keep-me"}, + ] + assert _extract_answer_from_content(content) == "keep-me" + + def test_list_skips_non_dict_parts(self): + assert _extract_answer_from_content([42, {"type": "text", "text": "x"}]) == "x" + + +class TestExtractThinkingFromChunk: + def test_metadata_reasoning_content(self): + chunk = MagicMock() + chunk.additional_kwargs = {"reasoning_content": "internal"} + assert _extract_thinking_from_chunk(chunk, "") == "internal" + + def test_metadata_reasoning_fallback(self): + chunk = MagicMock() + chunk.additional_kwargs = {"reasoning": "internal-r"} + assert _extract_thinking_from_chunk(chunk, "") == "internal-r" + + def test_metadata_thinking_fallback(self): + chunk = MagicMock() + chunk.additional_kwargs = {"thinking": "internal-t"} + assert _extract_thinking_from_chunk(chunk, "") == "internal-t" + + def test_falls_back_to_attribute(self): + class _Chunk: + additional_kwargs: dict = {} + reasoning_content = "attr" + + assert _extract_thinking_from_chunk(_Chunk(), "") == "attr" + + def test_appends_thinking_parts_from_content_list(self): + chunk = MagicMock() + chunk.additional_kwargs = {} + chunk.reasoning_content = "" + content = [ + {"type": "thinking", "text": "more"}, + {"type": "text", "text": "ignored"}, + ] + assert "more" in _extract_thinking_from_chunk(chunk, content) + + def test_handles_non_dict_additional_kwargs(self): + class _Chunk: + additional_kwargs = ["not", "a", "dict"] + reasoning_content = "fallback-attr" + + assert _extract_thinking_from_chunk(_Chunk(), "") == "fallback-attr" + + +class TestExtractStreamText: + def test_none_chunk(self): + assert _extract_stream_text(None) == ("", "") + + def test_full_chunk(self): + chunk = MagicMock() + chunk.content = [{"type": "text", "text": "answer"}] + chunk.additional_kwargs = {"reasoning_content": "thought"} + thinking, answer = _extract_stream_text(chunk) + assert thinking == "thought" + assert answer == "answer" + + +class TestStreamPartText: + def test_non_dict_returns_empty(self): + assert _stream_part_text("not a dict") == "" + assert _stream_part_text(None) == "" + assert _stream_part_text(42) == "" + + +class TestIsRootChainEnd: + def test_yes(self): + assert _is_root_chain_end({"event": "on_chain_end", "parent_ids": []}) is True + + def test_not_chain_end(self): + assert _is_root_chain_end({"event": "other"}) is False + + def test_has_parents(self): + assert _is_root_chain_end({"event": "on_chain_end", "parent_ids": ["p1"]}) is False + + +# ============================================================================= +# _ToolCallTextFilter — drives real text-stream behaviour +# ============================================================================= + + +class TestToolCallTextFilter: + """End-to-end behaviour of the leaked-tool-call stream filter.""" + + def test_empty_feed_returns_empty(self): + f = _ToolCallTextFilter() + assert f.feed("") == "" + + def test_passes_through_normal_text(self): + f = _ToolCallTextFilter() + assert f.feed("hello world. ") == "hello world. " + + def test_strips_complete_functions_call(self): + f = _ToolCallTextFilter() + out = f.feed('before functions.foo:1{"a": 1} after') + assert "functions.foo" not in out + assert "before" in out + assert "after" in (out + f.flush()) + + def test_strips_special_token(self): + f = _ToolCallTextFilter() + out = f.feed("hello <|tool_call_end|> world") + flushed = f.flush() + full = out + flushed + assert "tool_call_end" not in full + assert "hello" in full + assert "world" in full + + def test_buffers_partial_marker_across_chunks(self): + """The ``functions.`` prefix splits across chunks — filter must stay + stateful and never emit the prefix as plain text.""" + f = _ToolCallTextFilter() + out1 = f.feed("function") + assert "function" not in out1 # held back as unsafe-prefix tail + out2 = f.feed("s.") + out3 = f.feed('foo{"a": 1}') + flushed = f.flush() + full = out1 + out2 + out3 + flushed + assert "functions." not in full + + def test_unterminated_call_block_dropped_on_flush(self): + f = _ToolCallTextFilter() + f.feed('hello functions.foo{"a": 1') # no closing brace + assert f.flush() == "" # buffer dropped on flush + + def test_unterminated_special_token_dropped_on_flush(self): + f = _ToolCallTextFilter() + f.feed("hello <|im_start") # no closing |> + assert f.flush() == "" + + def test_flush_returns_remaining_safe_buffer(self): + f = _ToolCallTextFilter() + f.feed("ok ") + assert f.flush() == "" + + # Bytes whose tail is an unsafe-prefix of "functions." stay buffered + # until flush — feed() emits the safe prefix, flush() returns the tail. + f2 = _ToolCallTextFilter() + emitted = f2.feed("ok funct") + flushed = f2.flush() + assert emitted == "ok " + assert flushed == "funct" + + def test_string_with_braces_inside_args(self): + """The brace counter must be string-literal aware: ``"}"`` in the + JSON args must not pop the depth.""" + f = _ToolCallTextFilter() + out = f.feed('functions.x:0{"v": "{}"} ok') + flushed = f.flush() + full = out + flushed + assert "functions.x" not in full + assert "ok" in full + + def test_escaped_quote_inside_string(self): + f = _ToolCallTextFilter() + out = f.feed(r'functions.x{"v":"a\"b"} ok') + flushed = f.flush() + full = out + flushed + assert "functions.x" not in full + assert "ok" in full + + def test_special_token_with_tail_recursion(self): + """A special token in the middle of normal text — text on both + sides is preserved, the token is dropped.""" + f = _ToolCallTextFilter() + out = f.feed("alpha<|x|>beta") + flushed = f.flush() + full = out + flushed + assert "alpha" in full + assert "beta" in full + assert "<|x|>" not in full + + def test_in_call_nested_braces(self): + """Nested ``{`` inside the JSON args must increment depth so the + outer ``}`` doesn't close the call prematurely.""" + f = _ToolCallTextFilter() + out = f.feed('functions.x{"a":{"b":1}} done') + flushed = f.flush() + full = out + flushed + assert "functions.x" not in full + assert "done" in full + + def test_special_token_close_split_across_chunks(self): + """The ``|>`` close arrives in a later chunk than the ``<|`` open — + filter must finish in the ``in_special`` state mid-stream.""" + f = _ToolCallTextFilter() + f.feed("<|abc") + out = f.feed("def|>tail") + flushed = f.flush() + full = out + flushed + assert "tail" in full + assert "<|" not in full + + +# ============================================================================= +# _strip_special_tokens / _parse_inline_tool_calls / helpers +# ============================================================================= + + +class TestStripSpecialTokens: + def test_no_token_passes_through(self): + assert _strip_special_tokens("plain") == "plain" + + def test_strips_single_token(self): + assert _strip_special_tokens("a<|x|>b") == "ab" + + def test_unterminated_keeps_remainder(self): + result = _strip_special_tokens("a<|never closing") + assert "a" in result + + def test_empty_input(self): + assert _strip_special_tokens("") == "" + + +class TestParseInlineToolCalls: + def test_empty(self): + cleaned, calls = _parse_inline_tool_calls("") + assert cleaned == "" + assert calls == [] + + def test_no_marker_strips_special_tokens(self): + cleaned, calls = _parse_inline_tool_calls("hello <|tk|> world") + assert cleaned == "hello world" + assert calls == [] + + def test_one_call(self): + cleaned, calls = _parse_inline_tool_calls('go functions.foo{"a": 1} done') + assert "functions.foo" not in cleaned + assert len(calls) == 1 + assert calls[0]["name"] == "foo" + assert calls[0]["args"] == {"a": 1} + + def test_call_with_index_suffix(self): + _, calls = _parse_inline_tool_calls('functions.foo:42{"x": "y"}') + assert calls[0]["name"] == "foo" + assert calls[0]["args"] == {"x": "y"} + + def test_invalid_json_dropped(self): + _, calls = _parse_inline_tool_calls("functions.foo{42}") + assert calls == [] + + def test_array_args_dropped(self): + _, calls = _parse_inline_tool_calls("functions.foo{[1,2,3]}") + # ``{[1,2,3]}`` isn't valid JSON — call dropped + assert calls == [] + + def test_unterminated_payload_drops_tail(self): + cleaned, _ = _parse_inline_tool_calls('keep functions.foo{"a":1') + assert "keep" in cleaned + assert "functions" not in cleaned + + def test_no_name_keeps_one_char(self): + cleaned, _ = _parse_inline_tool_calls("functions.") + assert "f" in cleaned + + def test_marker_no_brace_keeps_text(self): + _, calls = _parse_inline_tool_calls("functions.foo done") + assert calls == [] + + +class TestConsumeOneInlineCall: + def test_no_name_returns_idx_plus_one(self): + out: list[str] = [] + next_i, call = _consume_one_inline_call("functions.@bar", 0, "functions.", out) + assert next_i == 1 + assert call is None + + def test_index_suffix_then_no_brace(self): + out: list[str] = [] + next_i, call = _consume_one_inline_call("functions.foo:5 ", 0, "functions.", out) + assert call is None + assert next_i is not None + + +class TestScanBalancedBraces: + def test_balanced(self): + assert _scan_balanced_braces("{a}", 0) == 3 + + def test_nested(self): + assert _scan_balanced_braces("{a{b}c}", 0) == 7 + + def test_string_with_brace(self): + s = '{"v": "}"}' + assert _scan_balanced_braces(s, 0) == len(s) + + def test_unterminated(self): + assert _scan_balanced_braces("{abc", 0) is None + + def test_escaped_quote(self): + s = '{"v":"a\\"b"}' + assert _scan_balanced_braces(s, 0) == len(s) + + +class TestTryParseCallArgs: + def test_valid_dict(self): + assert _try_parse_call_args('{"a": 1}') == {"a": 1} + + def test_invalid_json(self): + assert _try_parse_call_args("not-json") is None + + def test_non_dict_wrapped(self): + assert _try_parse_call_args("[1,2,3]") == {"value": [1, 2, 3]} + + +# ============================================================================= +# _coerce_todo_list / _extract_todos_from_tool_output +# ============================================================================= + + +class TestCoerceTodoList: + def test_list_of_dicts(self): + assert _coerce_todo_list([{"a": 1}, {"b": 2}]) == [{"a": 1}, {"b": 2}] + + def test_skips_non_dicts(self): + assert _coerce_todo_list([{"a": 1}, "junk", 3]) == [{"a": 1}] + + def test_empty_list_becomes_none(self): + assert _coerce_todo_list([]) is None + + def test_non_list_returns_none(self): + assert _coerce_todo_list("not a list") is None + assert _coerce_todo_list(None) is None + assert _coerce_todo_list({"foo": "bar"}) is None + + +class TestExtractTodosFromToolOutput: + def test_command_object_with_update(self): + class _Cmd: + update = {"todos": [{"content": "x"}]} + + assert _extract_todos_from_tool_output(_Cmd()) == [{"content": "x"}] + + def test_dict_with_nested_update(self): + out = {"update": {"todos": [{"x": 1}]}} + assert _extract_todos_from_tool_output(out) == [{"x": 1}] + + def test_dict_with_top_level_todos(self): + out = {"todos": [{"x": 2}]} + assert _extract_todos_from_tool_output(out) == [{"x": 2}] + + def test_plain_list(self): + assert _extract_todos_from_tool_output([{"a": 1}]) == [{"a": 1}] + + def test_json_string(self): + assert _extract_todos_from_tool_output('[{"a": 1}]') == [{"a": 1}] + + def test_invalid_json_string_returns_none(self): + assert _extract_todos_from_tool_output("not json {") is None + + def test_unsupported_type(self): + assert _extract_todos_from_tool_output(42) is None + + +# ============================================================================= +# Middleware helpers — _next_pending_plan_step / message rewriter +# ============================================================================= + + +class TestNextPendingPlanStep: + def test_no_messages(self): + assert _next_pending_plan_step({}) is None + assert _next_pending_plan_step({"messages": []}) is None + + def test_last_not_ai_message(self): + state = {"messages": [_NotAIMessage()], "todos": [{"status": "pending"}]} + assert _next_pending_plan_step(state) is None + + def test_pending_tool_calls_skips(self): + msg = AIMessage(tool_calls=[{"name": "x"}]) + state = {"messages": [msg], "todos": [{"status": "pending"}]} + assert _next_pending_plan_step(state) is None + + def test_no_todos(self): + msg = AIMessage() + assert _next_pending_plan_step({"messages": [msg]}) is None + assert _next_pending_plan_step({"messages": [msg], "todos": []}) is None + assert _next_pending_plan_step({"messages": [msg], "todos": "not a list"}) is None + + def test_failed_todo_blocks(self): + msg = AIMessage() + state = { + "messages": [msg], + "todos": [{"status": "completed"}, {"status": "failed"}], + } + assert _next_pending_plan_step(state) is None + + def test_skips_non_dict_todos(self): + msg = AIMessage() + state = { + "messages": [msg], + "todos": ["junk", {"status": "pending", "content": "next"}], + } + assert _next_pending_plan_step(state) == "next" + + def test_returns_first_non_completed(self): + msg = AIMessage() + state = { + "messages": [msg], + "todos": [ + {"status": "completed", "content": "done"}, + {"status": "pending", "content": "todo-1"}, + {"status": "pending", "content": "todo-2"}, + ], + } + assert _next_pending_plan_step(state) == "todo-1" + + def test_falls_back_to_title(self): + msg = AIMessage() + state = {"messages": [msg], "todos": [{"status": "pending", "title": "T"}]} + assert _next_pending_plan_step(state) == "T" + + def test_all_completed_returns_none(self): + msg = AIMessage() + state = { + "messages": [msg], + "todos": [{"status": "completed"}, {"status": "completed"}], + } + assert _next_pending_plan_step(state) is None + + +class TestFlattenMessageContent: + def test_string(self): + assert _flatten_message_content("hi") == "hi" + + def test_list_of_dicts(self): + assert _flatten_message_content([{"text": "a"}, {"text": "b"}]) == "ab" + + def test_list_of_non_dict(self): + assert _flatten_message_content([42, "x"]) == "42x" + + def test_unsupported_type(self): + assert _flatten_message_content(None) is None + + +class TestRewriteInlineToolCallMessage: + def test_non_ai_message_passthrough(self): + msg = _NotAIMessage() + assert _rewrite_inline_tool_call_message(msg) is msg + + def test_no_changes_returns_input(self): + msg = AIMessage(content="plain text", tool_calls=None) + out = _rewrite_inline_tool_call_message(msg) + assert out is msg + + def test_strips_tokens_and_appends_calls(self): + msg = AIMessage( + content='hello <|tok|>functions.f{"x":1} world', + tool_calls=[{"name": "existing", "args": {}}], + ) + out = _rewrite_inline_tool_call_message(msg) + assert out is not msg + assert "functions.f" not in out.content + names = [c["name"] for c in out.tool_calls] + assert "existing" in names + assert "f" in names + + def test_unsupported_content_passthrough(self): + msg = AIMessage(content=12345) + assert _rewrite_inline_tool_call_message(msg) is msg + + def test_special_tokens_only_rewrites(self): + """No ``functions.`` markup but ``<|...|>`` tokens — message is + still rewritten with cleaned content.""" + msg = AIMessage(content="hello <|tk|> world") + out = _rewrite_inline_tool_call_message(msg) + assert out is not msg + assert "<|tk|>" not in out.content + assert "hello" in out.content + assert "world" in out.content + + +class TestRewriteResponseMessages: + def test_response_with_list_result_rewrites(self): + class _Resp: + def __init__(self, result): + self.result = result + + plain = AIMessage(content="plain") + markup = AIMessage(content='hello functions.f{"x":1}') + resp = _Resp([plain, markup]) + out = _rewrite_response_messages(resp) + assert out is resp + assert isinstance(resp.result, list) + assert "functions.f" not in resp.result[1].content + + def test_response_without_list_result(self): + class _Resp: + result = None + + resp = _Resp() + assert _rewrite_response_messages(resp) is resp + + +# ============================================================================= +# Construction +# ============================================================================= + + class TestDeepagentProviderConstruction: - def test_with_pre_built_agent(self): + def test_with_pre_built_agent(self, provider_factory): agent = FakeAgent([]) - provider = DeepagentProvider(agent=agent) + provider = provider_factory(agent=agent) assert provider._agent is agent def test_without_agent_stores_params(self): @@ -71,90 +722,180 @@ def test_without_agent_stores_params(self): assert provider._agent is None assert provider._model == "openai:gpt-4o" + def test_default_recursion_limit_is_50(self): + provider = DeepagentProvider(model="openai:gpt-4o") + assert provider._recursion_limit == 50 + + def test_custom_recursion_limit(self): + provider = DeepagentProvider(model="openai:gpt-4o", recursion_limit=200) + assert provider._recursion_limit == 200 + + def test_mcp_servers_default_empty(self): + provider = DeepagentProvider(model="openai:gpt-4o") + assert provider._mcp_servers == {} + assert provider._mcp_tools == [] + + def test_mcp_servers_stored_on_init(self): + servers = { + "pywry": {"transport": "streamable_http", "url": "http://127.0.0.1:8765/mcp"}, + } + provider = DeepagentProvider(model="openai:gpt-4o", mcp_servers=servers) + assert provider._mcp_servers == servers + + +# ============================================================================= +# initialize() / new_session() / load_session() +# ============================================================================= + class TestDeepagentProviderInitialize: - @pytest.mark.asyncio - async def test_initialize_returns_capabilities(self): - agent = FakeAgent([]) - provider = DeepagentProvider(agent=agent, auto_checkpointer=False, auto_store=False) + async def test_returns_capabilities(self, provider_factory): + provider = provider_factory(agent=FakeAgent([])) caps = await provider.initialize(ClientCapabilities()) assert caps.prompt_capabilities is not None assert caps.prompt_capabilities.image is True - @pytest.mark.asyncio - async def test_initialize_with_checkpointer_enables_load(self): + async def test_with_checkpointer_enables_load(self, provider_factory): pytest.importorskip("langgraph") from langgraph.checkpoint.memory import MemorySaver - agent = FakeAgent([]) - provider = DeepagentProvider( - agent=agent, checkpointer=MemorySaver(), auto_checkpointer=False, auto_store=False - ) + provider = provider_factory(agent=FakeAgent([]), checkpointer=MemorySaver()) caps = await provider.initialize(ClientCapabilities()) assert caps.load_session is True - @pytest.mark.asyncio - async def test_initialize_without_checkpointer_disables_load(self): - agent = FakeAgent([]) - provider = DeepagentProvider(agent=agent, auto_checkpointer=False, auto_store=False) + async def test_without_checkpointer_disables_load(self, provider_factory): + provider = provider_factory(agent=FakeAgent([])) caps = await provider.initialize(ClientCapabilities()) assert caps.load_session is False + async def test_auto_creates_checkpointer_and_store(self): + """When ``auto_checkpointer=True`` / ``auto_store=True`` and the + provider has no agent yet, initialize() populates both side-effects.""" + provider = DeepagentProvider( + agent=MagicMock(), + auto_checkpointer=True, + auto_store=True, + ) + sentinel_cp = object() + sentinel_store = object() + provider._create_checkpointer = lambda: sentinel_cp # type: ignore[assignment] + provider._create_store = lambda: sentinel_store # type: ignore[assignment] + caps = await provider.initialize(ClientCapabilities()) + assert provider._checkpointer is sentinel_cp + assert provider._store is sentinel_store + assert caps.load_session is True + + async def test_builds_agent_when_none(self): + provider = DeepagentProvider(auto_checkpointer=False, auto_store=False) + sentinel = MagicMock() + provider._build_agent = lambda: sentinel # type: ignore[assignment] + await provider.initialize(ClientCapabilities()) + assert provider._agent is sentinel + class TestDeepagentProviderSessions: - @pytest.mark.asyncio - async def test_new_session_returns_id(self): - agent = FakeAgent([]) - provider = DeepagentProvider(agent=agent, auto_checkpointer=False, auto_store=False) + async def test_new_session_returns_id(self, provider_factory): + provider = provider_factory(agent=FakeAgent([])) await provider.initialize(ClientCapabilities()) sid = await provider.new_session("/tmp") assert sid.startswith("da_") - @pytest.mark.asyncio - async def test_load_nonexistent_session_raises(self): - agent = FakeAgent([]) - provider = DeepagentProvider(agent=agent, auto_checkpointer=False, auto_store=False) + async def test_load_nonexistent_session_raises(self, provider_factory): + provider = provider_factory(agent=FakeAgent([])) await provider.initialize(ClientCapabilities()) with pytest.raises(ValueError, match="not found"): await provider.load_session("nonexistent", "/tmp") + async def test_load_existing_returns_id(self): + provider = DeepagentProvider(model="x", auto_checkpointer=False, auto_store=False) + provider._sessions["sess1"] = "thread-A" + assert await provider.load_session("sess1", "/cwd") == "sess1" + + +class TestNewSessionMcpServers: + async def test_merges_stdio_descriptor(self, provider_factory): + provider = provider_factory(agent=FakeAgent([])) + await provider.initialize(ClientCapabilities()) + await provider.new_session( + "/tmp", + mcp_servers=[ + {"name": "fs", "command": "uvx", "args": ["mcp-server-filesystem", "/tmp"]}, + ], + ) + entry = provider._mcp_servers["fs"] + assert entry["transport"] == "stdio" + assert entry["command"] == "uvx" + assert entry["args"] == ["mcp-server-filesystem", "/tmp"] + # Mutating mcp_servers forces an agent + tool rebuild on next prompt + assert provider._agent is None + assert provider._mcp_tools == [] + + async def test_merges_http_descriptor(self, provider_factory): + provider = provider_factory(agent=FakeAgent([])) + await provider.initialize(ClientCapabilities()) + await provider.new_session( + "/tmp", + mcp_servers=[{"name": "pywry", "url": "http://127.0.0.1:8765/mcp"}], + ) + entry = provider._mcp_servers["pywry"] + assert entry["transport"] == "streamable_http" + assert entry["url"] == "http://127.0.0.1:8765/mcp" + + async def test_no_mcp_keeps_existing_agent(self, provider_factory): + agent = FakeAgent([]) + provider = provider_factory(agent=agent) + await provider.initialize(ClientCapabilities()) + await provider.new_session("/tmp") + assert provider._agent is agent + + async def test_skips_non_dict_entry(self): + provider = DeepagentProvider( + agent=MagicMock(), auto_checkpointer=False, auto_store=False + ) + await provider.initialize(ClientCapabilities()) + await provider.new_session("/tmp", mcp_servers=["junk", 42]) + assert provider._mcp_servers == {} + + async def test_no_name_falls_back_to_uuid_prefix(self): + provider = DeepagentProvider( + agent=MagicMock(), auto_checkpointer=False, auto_store=False + ) + await provider.initialize(ClientCapabilities()) + await provider.new_session("/tmp", mcp_servers=[{"command": "x", "args": []}]) + assert any(k.startswith("acp_") for k in provider._mcp_servers) + + +# ============================================================================= +# Streaming behaviour +# ============================================================================= + class TestDeepagentProviderStreaming: - @pytest.mark.asyncio - async def test_text_chunks(self): + async def test_text_chunks(self, provider_factory): events = [ make_event("on_chat_model_stream", data={"chunk": FakeChunk("hello ")}), make_event("on_chat_model_stream", data={"chunk": FakeChunk("world")}), ] - agent = FakeAgent(events) - provider = DeepagentProvider(agent=agent, auto_checkpointer=False, auto_store=False) + provider = provider_factory(agent=FakeAgent(events)) await provider.initialize(ClientCapabilities()) sid = await provider.new_session("/tmp") - updates = [] - async for u in provider.prompt(sid, [TextPart(text="hi")]): - updates.append(u) - + updates = await _drain_prompt(provider, sid) assert len(updates) == 2 assert all(isinstance(u, AgentMessageUpdate) for u in updates) assert updates[0].text == "hello " assert updates[1].text == "world" - @pytest.mark.asyncio - async def test_tool_call_lifecycle(self): + async def test_tool_call_lifecycle(self, provider_factory): events = [ make_event("on_tool_start", name="read_file", run_id="tc1"), make_event("on_tool_end", name="read_file", run_id="tc1", data={"output": "contents"}), ] - agent = FakeAgent(events) - provider = DeepagentProvider(agent=agent, auto_checkpointer=False, auto_store=False) + provider = provider_factory(agent=FakeAgent(events)) await provider.initialize(ClientCapabilities()) sid = await provider.new_session("/tmp") - updates = [] - async for u in provider.prompt(sid, [TextPart(text="read")]): - updates.append(u) - + updates = await _drain_prompt(provider, sid, "read") assert len(updates) == 2 assert isinstance(updates[0], ToolCallUpdate) assert updates[0].status == "in_progress" @@ -162,25 +903,19 @@ async def test_tool_call_lifecycle(self): assert isinstance(updates[1], ToolCallUpdate) assert updates[1].status == "completed" - @pytest.mark.asyncio - async def test_tool_error(self): + async def test_tool_error(self, provider_factory): events = [ make_event("on_tool_start", name="execute", run_id="tc2"), make_event("on_tool_error", name="execute", run_id="tc2"), ] - agent = FakeAgent(events) - provider = DeepagentProvider(agent=agent, auto_checkpointer=False, auto_store=False) + provider = provider_factory(agent=FakeAgent(events)) await provider.initialize(ClientCapabilities()) sid = await provider.new_session("/tmp") - updates = [] - async for u in provider.prompt(sid, [TextPart(text="run")]): - updates.append(u) - + updates = await _drain_prompt(provider, sid, "run") assert updates[-1].status == "failed" - @pytest.mark.asyncio - async def test_write_todos_produces_plan_update(self): + async def test_write_todos_produces_plan_update(self, provider_factory): import json todos = [ @@ -190,18 +925,17 @@ async def test_write_todos_produces_plan_update(self): events = [ make_event("on_tool_start", name="write_todos", run_id="tc3"), make_event( - "on_tool_end", name="write_todos", run_id="tc3", data={"output": json.dumps(todos)} + "on_tool_end", + name="write_todos", + run_id="tc3", + data={"output": json.dumps(todos)}, ), ] - agent = FakeAgent(events) - provider = DeepagentProvider(agent=agent, auto_checkpointer=False, auto_store=False) + provider = provider_factory(agent=FakeAgent(events)) await provider.initialize(ClientCapabilities()) sid = await provider.new_session("/tmp") - updates = [] - async for u in provider.prompt(sid, [TextPart(text="plan")]): - updates.append(u) - + updates = await _drain_prompt(provider, sid, "plan") plan_updates = [u for u in updates if isinstance(u, PlanUpdate)] assert len(plan_updates) == 1 assert len(plan_updates[0].entries) == 2 @@ -209,11 +943,13 @@ async def test_write_todos_produces_plan_update(self): assert plan_updates[0].entries[0].status == "completed" assert plan_updates[0].entries[1].status == "in_progress" - @pytest.mark.asyncio - async def test_write_todos_langgraph_command_output_produces_plan_update(self): + async def test_write_todos_langgraph_command_output(self, provider_factory): """Deep Agents' ``write_todos`` returns a LangGraph ``Command`` with ``update={"todos": [...]}`` — the extractor must pull the list out of that shape, not just the legacy plain-JSON list. + + The plan card IS the visualization — no raw Command repr should + double-render as a tool-call card. """ class _Command: @@ -233,15 +969,11 @@ def __init__(self, update: dict) -> None: data={"output": _Command(update={"todos": todos})}, ), ] - agent = FakeAgent(events) - provider = DeepagentProvider(agent=agent, auto_checkpointer=False, auto_store=False) + provider = provider_factory(agent=FakeAgent(events)) await provider.initialize(ClientCapabilities()) sid = await provider.new_session("/tmp") - updates = [] - async for u in provider.prompt(sid, [TextPart(text="plan")]): - updates.append(u) - + updates = await _drain_prompt(provider, sid, "plan") plan_updates = [u for u in updates if isinstance(u, PlanUpdate)] assert len(plan_updates) == 1 assert [e.content for e in plan_updates[0].entries] == [ @@ -249,8 +981,6 @@ def __init__(self, update: dict) -> None: "Change interval to 1m", ] assert [e.status for e in plan_updates[0].entries] == ["completed", "in_progress"] - # The plan card IS the visualization — no raw Command repr should - # double-render as a tool-call card. tool_completed = [ u for u in updates @@ -258,77 +988,217 @@ def __init__(self, update: dict) -> None: ] assert tool_completed == [] - @pytest.mark.asyncio - async def test_cancel_stops_streaming(self): + async def test_writes_todos_emits_planning_status_only_at_start(self, provider_factory): + events = [make_event("on_tool_start", name="write_todos")] + provider = provider_factory(agent=FakeAgent(events)) + await provider.initialize(ClientCapabilities()) + sid = await provider.new_session("/tmp") + + out = await _drain_prompt(provider, sid) + statuses = [u for u in out if isinstance(u, StatusUpdate)] + assert any("Planning" in s.text for s in statuses) + + async def test_cancel_stops_streaming(self, provider_factory): events = [ make_event("on_chat_model_stream", data={"chunk": FakeChunk(f"chunk{i}")}) for i in range(100) ] - agent = FakeAgent(events) - provider = DeepagentProvider(agent=agent, auto_checkpointer=False, auto_store=False) + provider = provider_factory(agent=FakeAgent(events)) await provider.initialize(ClientCapabilities()) sid = await provider.new_session("/tmp") cancel = asyncio.Event() updates = [] - count = 0 async for u in provider.prompt(sid, [TextPart(text="go")], cancel_event=cancel): updates.append(u) - count += 1 - if count == 3: + if len(updates) == 3: cancel.set() assert len(updates) < 100 - @pytest.mark.asyncio - async def test_chat_model_start_yields_status(self): + async def test_chat_model_start_yields_status(self, provider_factory): events = [ make_event("on_chat_model_start", name="ChatOpenAI"), make_event("on_chat_model_stream", data={"chunk": FakeChunk("answer")}), ] - agent = FakeAgent(events) - provider = DeepagentProvider(agent=agent, auto_checkpointer=False, auto_store=False) + provider = provider_factory(agent=FakeAgent(events)) await provider.initialize(ClientCapabilities()) sid = await provider.new_session("/tmp") - updates = [] - async for u in provider.prompt(sid, [TextPart(text="hi")]): - updates.append(u) - + updates = await _drain_prompt(provider, sid) assert isinstance(updates[0], StatusUpdate) assert "ChatOpenAI" in updates[0].text assert isinstance(updates[1], AgentMessageUpdate) + async def test_chat_model_start_no_name_yields_thinking(self, provider_factory): + events = [make_event("on_chat_model_start", name="")] + provider = provider_factory(agent=FakeAgent(events)) + await provider.initialize(ClientCapabilities()) + sid = await provider.new_session("/tmp") + + out = await _drain_prompt(provider, sid) + statuses = [u for u in out if isinstance(u, StatusUpdate)] + assert any(s.text == "Thinking..." for s in statuses) + + async def test_subagent_task_emits_status(self, provider_factory): + events = [make_event("on_chain_start", name="task")] + provider = provider_factory(agent=FakeAgent(events)) + await provider.initialize(ClientCapabilities()) + sid = await provider.new_session("/tmp") + + out = await _drain_prompt(provider, sid) + statuses = [u for u in out if isinstance(u, StatusUpdate)] + assert any("Delegating to subagent" in s.text for s in statuses) + + async def test_chat_model_stream_yields_thinking_only(self, provider_factory): + events = [ + make_event( + "on_chat_model_stream", + data={"chunk": FakeChunk("", additional_kwargs={"reasoning_content": "internal"})}, + ) + ] + provider = provider_factory(agent=FakeAgent(events)) + await provider.initialize(ClientCapabilities()) + sid = await provider.new_session("/tmp") + + out = await _drain_prompt(provider, sid) + thinking = [u for u in out if isinstance(u, ThinkingUpdate)] + assert thinking and thinking[0].text == "internal" + + async def test_prompt_builds_agent_lazily(self): + """When ``self._agent is None``, prompt() calls ``_build_agent``.""" + provider = DeepagentProvider(model="x", auto_checkpointer=False, auto_store=False) + sentinel_agent = FakeAgent([]) + provider._build_agent = lambda: sentinel_agent # type: ignore[assignment] + sid = await provider.new_session("/tmp") + assert provider._agent is None + await _drain_prompt(provider, sid) + assert provider._agent is sentinel_agent + + async def test_chain_end_emits_safe_buffer_tail(self, provider_factory): + """An unsafe-prefix tail in the filter buffer (``"hi func"``) gets + emitted as the chain-end flush since neither ``in_call`` nor + ``in_special`` is true.""" + events = [ + make_event("on_chat_model_stream", data={"chunk": FakeChunk("hi func")}), + make_event("on_chain_end", data={}, parent_ids=[]), + ] + provider = provider_factory(agent=FakeAgent(events)) + await provider.initialize(ClientCapabilities()) + sid = await provider.new_session("/tmp") + + out = await _drain_prompt(provider, sid) + full = "".join(getattr(u, "text", "") for u in out if isinstance(u, AgentMessageUpdate)) + assert "func" in full + + async def test_chain_end_flushes_dropping_unclosed_markup(self, provider_factory): + """A partial ``functions.`` prefix never completes — chain-end flush + drops it (the buffer is in ``in_call`` would-be state). Only the + safe text before the marker survives.""" + events = [ + make_event("on_chat_model_stream", data={"chunk": FakeChunk("hello fun")}), + make_event("on_chain_end", data={}, parent_ids=[]), + ] + provider = provider_factory(agent=FakeAgent(events)) + await provider.initialize(ClientCapabilities()) + sid = await provider.new_session("/tmp") + + out = await _drain_prompt(provider, sid) + texts = [getattr(u, "text", "") for u in out if isinstance(u, AgentMessageUpdate)] + assert any("hello" in t for t in texts) + # ============================================================================= -# MCP integration / recursion_limit / truncate_session -# ============================================================================= +# _handle_tool_end edge cases +# ============================================================================= -class TestDeepagentProviderConstructor: - def test_default_recursion_limit_is_50(self): - provider = DeepagentProvider(model="openai:gpt-4o") - assert provider._recursion_limit == 50 +class TestHandleToolEnd: + async def test_object_content(self, provider_factory): + class _Out: + content = "tool-out-text" - def test_custom_recursion_limit(self): - provider = DeepagentProvider(model="openai:gpt-4o", recursion_limit=200) - assert provider._recursion_limit == 200 + events = [make_event("on_tool_end", name="ls", run_id="r1", data={"output": _Out()})] + provider = provider_factory(agent=FakeAgent(events)) + await provider.initialize(ClientCapabilities()) + sid = await provider.new_session("/tmp") - def test_mcp_servers_default_empty(self): - provider = DeepagentProvider(model="openai:gpt-4o") - assert provider._mcp_servers == {} - assert provider._mcp_tools == [] + out = await _drain_prompt(provider, sid) + completed = [u for u in out if isinstance(u, ToolCallUpdate) and u.status == "completed"] + assert completed + assert completed[0].content[0]["text"] == "tool-out-text" - def test_mcp_servers_stored_on_init(self): - servers = { - "pywry": {"transport": "streamable_http", "url": "http://127.0.0.1:8765/mcp"}, - } - provider = DeepagentProvider(model="openai:gpt-4o", mcp_servers=servers) - assert provider._mcp_servers == servers + async def test_dict_output_json_encoded(self, provider_factory): + events = [make_event("on_tool_end", name="ls", run_id="r1", data={"output": {"a": 1}})] + provider = provider_factory(agent=FakeAgent(events)) + await provider.initialize(ClientCapabilities()) + sid = await provider.new_session("/tmp") + + out = await _drain_prompt(provider, sid) + completed = [u for u in out if isinstance(u, ToolCallUpdate) and u.status == "completed"] + assert completed + assert '"a": 1' in completed[0].content[0]["text"] + + async def test_scalar_output_stringified(self, provider_factory): + events = [make_event("on_tool_end", name="ls", run_id="r1", data={"output": 42})] + provider = provider_factory(agent=FakeAgent(events)) + await provider.initialize(ClientCapabilities()) + sid = await provider.new_session("/tmp") + + out = await _drain_prompt(provider, sid) + completed = [u for u in out if isinstance(u, ToolCallUpdate) and u.status == "completed"] + assert completed + assert "42" in completed[0].content[0]["text"] + + async def test_none_output_omits_content(self, provider_factory): + events = [make_event("on_tool_end", name="ls", run_id="r1", data={"output": None})] + provider = provider_factory(agent=FakeAgent(events)) + await provider.initialize(ClientCapabilities()) + sid = await provider.new_session("/tmp") + + out = await _drain_prompt(provider, sid) + completed = [u for u in out if isinstance(u, ToolCallUpdate) and u.status == "completed"] + assert completed + assert completed[0].content is None + + async def test_no_run_id_uses_generated_id(self, provider_factory): + events = [make_event("on_tool_end", name="ls", run_id="", data={"output": "x"})] + provider = provider_factory(agent=FakeAgent(events)) + await provider.initialize(ClientCapabilities()) + sid = await provider.new_session("/tmp") + + out = await _drain_prompt(provider, sid) + completed = [u for u in out if isinstance(u, ToolCallUpdate) and u.status == "completed"] + assert completed + assert completed[0].tool_call_id.startswith("call_") + + async def test_json_fallback_on_dump_failure(self, provider_factory): + """When the dict has a value whose ``__str__`` raises, ``json.dumps( + ..., default=str)`` raises — fall back to plain ``str(output)``.""" + + class _Boom: + def __str__(self): + raise RuntimeError("boom") + + events = [ + make_event("on_tool_end", name="ls", run_id="r1", data={"output": {"k": _Boom()}}), + ] + provider = provider_factory(agent=FakeAgent(events)) + await provider.initialize(ClientCapabilities()) + sid = await provider.new_session("/tmp") + + out = await _drain_prompt(provider, sid) + completed = [u for u in out if isinstance(u, ToolCallUpdate) and u.status == "completed"] + assert completed + assert completed[0].content is not None + + +# ============================================================================= +# Config / recursion-limit +# ============================================================================= class TestRecursionLimitInPromptConfig: - @pytest.mark.asyncio async def test_recursion_limit_passed_in_config(self): captured: dict = {} @@ -350,83 +1220,368 @@ async def _empty(): ) await provider.initialize(ClientCapabilities()) sid = await provider.new_session("/tmp") - async for _ in provider.prompt(sid, [TextPart(text="hi")]): - pass + await _drain_prompt(provider, sid) assert captured["config"]["recursion_limit"] == 42 assert captured["config"]["configurable"]["thread_id"] -class TestNewSessionMcpServers: - @pytest.mark.asyncio - async def test_new_session_merges_stdio_descriptor(self): +# ============================================================================= +# Internal helpers — _create_checkpointer / _create_store / _load_mcp_tools +# ============================================================================= + + +class TestCreateCheckpointer: + def test_returns_memory_saver_when_no_backend(self, monkeypatch): + provider = DeepagentProvider(model="x", auto_checkpointer=False, auto_store=False) + # Force the state-backend probe to fail → fall through to MemorySaver + fake_state_factory = MagicMock() + fake_state_factory.get_state_backend = lambda: (_ for _ in ()).throw(RuntimeError("nope")) + monkeypatch.setitem(sys.modules, "pywry.state._factory", fake_state_factory) + result = provider._create_checkpointer() + assert result is not None # MemorySaver + + def test_returns_none_when_langgraph_missing(self, monkeypatch): + provider = DeepagentProvider(model="x", auto_checkpointer=False, auto_store=False) + fake_state_factory = MagicMock() + fake_state_factory.get_state_backend = lambda: (_ for _ in ()).throw(RuntimeError("nope")) + monkeypatch.setitem(sys.modules, "pywry.state._factory", fake_state_factory) + + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "langgraph.checkpoint.memory": + raise ImportError("no langgraph") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + assert provider._create_checkpointer() is None + + def test_redis_path_uses_redis_saver(self, monkeypatch): + from pywry.state.types import StateBackend + + provider = DeepagentProvider(model="x", auto_checkpointer=False, auto_store=False) + + fake_state_factory = MagicMock() + fake_state_factory.get_state_backend = lambda: StateBackend.REDIS + monkeypatch.setitem(sys.modules, "pywry.state._factory", fake_state_factory) + + sentinel = object() + fake_saver_module = MagicMock() + fake_saver_module.RedisSaver = lambda url: sentinel + monkeypatch.setitem(sys.modules, "langgraph.checkpoint.redis", fake_saver_module) + + fake_settings_obj = MagicMock() + fake_settings_obj.deploy.redis_url = "redis://localhost:6379" + fake_config_module = MagicMock() + fake_config_module.get_settings = lambda: fake_settings_obj + monkeypatch.setitem(sys.modules, "pywry.config", fake_config_module) + + assert provider._create_checkpointer() is sentinel + + def test_sqlite_path_uses_sqlite_saver(self, monkeypatch): + from pywry.state.types import StateBackend + + provider = DeepagentProvider(model="x", auto_checkpointer=False, auto_store=False) + + fake_state_factory = MagicMock() + fake_state_factory.get_state_backend = lambda: StateBackend.SQLITE + monkeypatch.setitem(sys.modules, "pywry.state._factory", fake_state_factory) + + sentinel = object() + + class _SqliteSaver: + @classmethod + def from_conn_string(cls, _db_path): + return sentinel + + fake_saver_module = MagicMock() + fake_saver_module.SqliteSaver = _SqliteSaver + monkeypatch.setitem(sys.modules, "langgraph.checkpoint.sqlite", fake_saver_module) + + fake_settings_obj = MagicMock() + fake_settings_obj.deploy.sqlite_path = "/tmp/state.db" + fake_config_module = MagicMock() + fake_config_module.get_settings = lambda: fake_settings_obj + monkeypatch.setitem(sys.modules, "pywry.config", fake_config_module) + + assert provider._create_checkpointer() is sentinel + + def test_sqlite_path_falls_through_when_sqlite_missing(self, monkeypatch): + from pywry.state.types import StateBackend + + provider = DeepagentProvider(model="x", auto_checkpointer=False, auto_store=False) + + fake_state_factory = MagicMock() + fake_state_factory.get_state_backend = lambda: StateBackend.SQLITE + monkeypatch.setitem(sys.modules, "pywry.state._factory", fake_state_factory) + + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "langgraph.checkpoint.sqlite": + raise ImportError("no sqlite") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + # Should fall back to MemorySaver + assert provider._create_checkpointer() is not None + + +class TestCreateStore: + def test_returns_in_memory_store(self): + provider = DeepagentProvider(model="x", auto_checkpointer=False, auto_store=False) + assert provider._create_store() is not None + + def test_returns_none_when_langgraph_missing(self, monkeypatch): + provider = DeepagentProvider(model="x", auto_checkpointer=False, auto_store=False) + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "langgraph.store.memory": + raise ImportError("no langgraph") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + assert provider._create_store() is None + + +class TestLoadMcpTools: + def test_no_servers_returns_empty(self): + provider = DeepagentProvider(model="x", auto_checkpointer=False, auto_store=False) + assert provider._load_mcp_tools() == [] + + def test_no_adapter_warns_and_returns_empty(self, monkeypatch): provider = DeepagentProvider( - agent=FakeAgent([]), + model="x", + mcp_servers={"fs": {"transport": "stdio"}}, auto_checkpointer=False, auto_store=False, ) - await provider.initialize(ClientCapabilities()) - await provider.new_session( - "/tmp", - mcp_servers=[ - {"name": "fs", "command": "uvx", "args": ["mcp-server-filesystem", "/tmp"]}, - ], + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "langchain_mcp_adapters.client": + raise ImportError("no adapter") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + assert provider._load_mcp_tools() == [] + + def test_with_running_loop_uses_threadpool(self, monkeypatch): + """When called from inside an event loop, the implementation runs + ``client.get_tools()`` on a threadpool executor.""" + provider = DeepagentProvider( + model="x", + mcp_servers={"fs": {"transport": "stdio"}}, + auto_checkpointer=False, + auto_store=False, ) - assert "fs" in provider._mcp_servers - entry = provider._mcp_servers["fs"] - assert entry["transport"] == "stdio" - assert entry["command"] == "uvx" - assert entry["args"] == ["mcp-server-filesystem", "/tmp"] - # Forces a rebuild on next prompt - assert provider._agent is None - assert provider._mcp_tools == [] - @pytest.mark.asyncio - async def test_new_session_merges_http_descriptor(self): + async def _get_tools(): + return ["tool-A", "tool-B"] + + class _Client: + def __init__(self, _cfg): + pass + + def get_tools(self): + return _get_tools() + + fake_client_module = MagicMock() + fake_client_module.MultiServerMCPClient = _Client + monkeypatch.setitem(sys.modules, "langchain_mcp_adapters.client", fake_client_module) + + async def _runner(): + return provider._load_mcp_tools() + + assert asyncio.run(_runner()) == ["tool-A", "tool-B"] + + def test_get_tools_failure_returns_empty(self, monkeypatch): provider = DeepagentProvider( - agent=FakeAgent([]), + model="x", + mcp_servers={"fs": {"transport": "stdio"}}, auto_checkpointer=False, auto_store=False, ) - await provider.initialize(ClientCapabilities()) - await provider.new_session( - "/tmp", - mcp_servers=[ - {"name": "pywry", "url": "http://127.0.0.1:8765/mcp"}, - ], + + class _Client: + def __init__(self, _cfg): + pass + + def get_tools(self): + raise RuntimeError("boom") + + fake_client_module = MagicMock() + fake_client_module.MultiServerMCPClient = _Client + monkeypatch.setitem(sys.modules, "langchain_mcp_adapters.client", fake_client_module) + assert provider._load_mcp_tools() == [] + + +# ============================================================================= +# _build_agent_kwargs / _build_agent +# ============================================================================= + + +class TestBuildAgentKwargs: + def test_minimal_kwargs(self, monkeypatch): + # Stub middleware factories so they don't pull in langchain + monkeypatch.setattr(da, "_build_inline_tool_call_middleware", lambda: None) + monkeypatch.setattr(da, "_build_plan_continuation_middleware", lambda: None) + + provider = DeepagentProvider( + model="x", + auto_checkpointer=False, + auto_store=False, ) - entry = provider._mcp_servers["pywry"] - assert entry["transport"] == "streamable_http" - assert entry["url"] == "http://127.0.0.1:8765/mcp" + kwargs = provider._build_agent_kwargs([], "system-prompt") + assert kwargs["model"] == "x" + assert kwargs["system_prompt"] == "system-prompt" + # Middlewares are None and user list is empty → not added + assert "middleware" not in kwargs + + def test_full_kwargs_with_user_middleware(self, monkeypatch): + sentinel_inline = object() + sentinel_plan = object() + sentinel_user = object() + monkeypatch.setattr(da, "_build_inline_tool_call_middleware", lambda: sentinel_inline) + monkeypatch.setattr(da, "_build_plan_continuation_middleware", lambda: sentinel_plan) - @pytest.mark.asyncio - async def test_new_session_no_mcp_keeps_existing_agent(self): - agent = FakeAgent([]) provider = DeepagentProvider( - agent=agent, + model="x", + tools=[lambda: None], + interrupt_on={"tool": "ask"}, + backend="memory", + subagents=[{"name": "sub"}], + skills=["/path/SKILL.md"], + middleware=[sentinel_user], + checkpointer=object(), + store=object(), + memory=["/AGENTS.md"], auto_checkpointer=False, auto_store=False, ) - await provider.initialize(ClientCapabilities()) - await provider.new_session("/tmp") - # Without mcp_servers param the agent is preserved - assert provider._agent is agent + kwargs = provider._build_agent_kwargs(provider._tools, "p") + assert kwargs["middleware"] == [sentinel_inline, sentinel_plan, sentinel_user] + for key in ( + "tools", + "checkpointer", + "interrupt_on", + "backend", + "subagents", + "skills", + "store", + "memory", + ): + assert key in kwargs + + +class TestBuildAgent: + """Drive _build_agent through every system-prompt branch and MCP integration.""" + + def _patch_create_deep_agent(self, monkeypatch, captured): + import types + fake_module = types.ModuleType("deepagents") -class TestLoadMcpTools: - def test_returns_empty_when_no_servers_configured(self): - provider = DeepagentProvider(model="openai:gpt-4o") - assert provider._load_mcp_tools() == [] + def _fake_create_deep_agent(**kwargs): + captured.update(kwargs) + return MagicMock(name="agent") + + fake_module.create_deep_agent = _fake_create_deep_agent + monkeypatch.setitem(sys.modules, "deepagents", fake_module) + monkeypatch.setattr(da, "_build_inline_tool_call_middleware", lambda: None) + monkeypatch.setattr(da, "_build_plan_continuation_middleware", lambda: None) + def test_default_system_prompt_is_pywry(self, monkeypatch): + captured: dict = {} + self._patch_create_deep_agent(monkeypatch, captured) + provider = DeepagentProvider(model="x", auto_checkpointer=False, auto_store=False) + provider._build_agent() + assert captured["system_prompt"] == da.PYWRY_SYSTEM_PROMPT -class TestTruncateSession: - def test_no_op_when_checkpointer_missing(self): + def test_appended_system_prompt(self, monkeypatch): + captured: dict = {} + self._patch_create_deep_agent(monkeypatch, captured) + provider = DeepagentProvider( + model="x", + system_prompt="be brief", + auto_checkpointer=False, + auto_store=False, + ) + provider._build_agent() + assert da.PYWRY_SYSTEM_PROMPT in captured["system_prompt"] + assert "be brief" in captured["system_prompt"] + + def test_replace_system_prompt(self, monkeypatch): + captured: dict = {} + self._patch_create_deep_agent(monkeypatch, captured) + provider = DeepagentProvider( + model="x", + system_prompt="only-this", + replace_system_prompt=True, + auto_checkpointer=False, + auto_store=False, + ) + provider._build_agent() + assert captured["system_prompt"] == "only-this" + assert da.PYWRY_SYSTEM_PROMPT not in captured["system_prompt"] + + def test_loads_mcp_tools_when_servers_configured(self, monkeypatch): + captured: dict = {} + self._patch_create_deep_agent(monkeypatch, captured) provider = DeepagentProvider( - model="openai:gpt-4o", + model="x", + mcp_servers={"fs": {"transport": "stdio"}}, auto_checkpointer=False, auto_store=False, ) - # Should not raise even without a checkpointer - provider.truncate_session("session-1", []) + provider._load_mcp_tools = lambda: ["tool-X"] # type: ignore[assignment] + provider._build_agent() + assert "tool-X" in captured["tools"] + + +class TestAutoCheckpointerInBuildAgent: + """The auto-checkpointer must be set up by _build_agent so callers that + bypass the async initialize() still get conversation persistence.""" + + def test_build_agent_creates_checkpointer_when_missing(self, monkeypatch): + import types + + provider = DeepagentProvider(model="openai:gpt-4o", auto_checkpointer=True) + assert provider._checkpointer is None + + fake_module = types.ModuleType("deepagents") + fake_module.create_deep_agent = lambda **kwargs: object() + monkeypatch.setitem(sys.modules, "deepagents", fake_module) + + provider._build_agent() + assert provider._checkpointer is not None + + def test_build_agent_does_not_overwrite_existing_checkpointer(self, monkeypatch): + import types + + sentinel = object() + provider = DeepagentProvider( + model="openai:gpt-4o", checkpointer=sentinel, auto_checkpointer=True + ) + + fake_module = types.ModuleType("deepagents") + fake_module.create_deep_agent = lambda **kwargs: object() + monkeypatch.setitem(sys.modules, "deepagents", fake_module) + + provider._build_agent() + assert provider._checkpointer is sentinel + + +# ============================================================================= +# truncate_session +# ============================================================================= + + +class TestTruncateSession: + def test_no_op_when_checkpointer_missing(self): + provider = DeepagentProvider(model="x", auto_checkpointer=False, auto_store=False) + provider.truncate_session("session-1", []) # no-op, no raise def test_calls_delete_thread_when_available(self): deleted: list[str] = [] @@ -436,10 +1591,7 @@ def delete_thread(self, thread_id: str) -> None: deleted.append(thread_id) provider = DeepagentProvider( - model="openai:gpt-4o", - checkpointer=_Saver(), - auto_checkpointer=False, - auto_store=False, + model="x", checkpointer=_Saver(), auto_checkpointer=False, auto_store=False ) provider._sessions["sess-1"] = "thread-A" provider.truncate_session("sess-1", []) @@ -452,57 +1604,230 @@ def __init__(self) -> None: saver = _DictSaver() provider = DeepagentProvider( - model="openai:gpt-4o", - checkpointer=saver, - auto_checkpointer=False, - auto_store=False, + model="x", checkpointer=saver, auto_checkpointer=False, auto_store=False ) provider._sessions["sess-1"] = "thread-A" provider.truncate_session("sess-1", []) assert "thread-A" not in saver.storage assert "thread-B" in saver.storage # other threads untouched + def test_falls_back_to_adelete_thread_when_sync_missing(self): + deleted: list[str] = [] -class TestAutoCheckpointerInBuildAgent: - """The auto-checkpointer must be set up by _build_agent so callers that - bypass the async initialize() still get conversation persistence.""" + class _AsyncSaver: + async def adelete_thread(self, thread_id: str) -> None: + deleted.append(thread_id) - def test_build_agent_creates_checkpointer_when_missing(self, monkeypatch): - # Pre-empt the actual create_deep_agent import; we only care about - # the side-effect on self._checkpointer. provider = DeepagentProvider( - model="openai:gpt-4o", - auto_checkpointer=True, + model="x", checkpointer=_AsyncSaver(), auto_checkpointer=False, auto_store=False ) - assert provider._checkpointer is None + provider._sessions["s1"] = "thread-A" + provider.truncate_session("s1", []) + assert deleted == ["thread-A"] - # Patch create_deep_agent to a stub so _build_agent doesn't need - # the real deepagents package. - import sys - import types + def test_remaps_session_id_after_dict_pop(self): + class _DictSaver: + def __init__(self): + self.storage = {"thread-A": "junk"} - fake_module = types.ModuleType("deepagents") - fake_module.create_deep_agent = lambda **kwargs: object() - monkeypatch.setitem(sys.modules, "deepagents", fake_module) + saver = _DictSaver() + provider = DeepagentProvider( + model="x", checkpointer=saver, auto_checkpointer=False, auto_store=False + ) + provider._sessions["s1"] = "thread-A" + provider.truncate_session("s1", []) + assert "thread-A" not in saver.storage + assert provider._sessions["s1"].startswith("thread-A:") - provider._build_agent() - # Checkpointer was set as a side-effect - assert provider._checkpointer is not None + def test_delete_thread_exception_falls_through_to_adelete(self): + deleted: list[str] = [] + + class _BoomSaver: + def delete_thread(self, _tid: str) -> None: + raise RuntimeError("boom") + + def adelete_thread(self, tid: str): + async def _go(): + deleted.append(tid) + + return _go() - def test_build_agent_does_not_overwrite_existing_checkpointer(self, monkeypatch): - sentinel = object() provider = DeepagentProvider( - model="openai:gpt-4o", - checkpointer=sentinel, - auto_checkpointer=True, + model="x", checkpointer=_BoomSaver(), auto_checkpointer=False, auto_store=False ) + provider._sessions["s1"] = "thread-A" + provider.truncate_session("s1", []) + assert deleted == ["thread-A"] - import sys - import types + def test_async_inside_running_loop_uses_threadpool(self): + """Calling truncate_session from inside a running loop dispatches the + async deletion via a dedicated thread.""" + deleted: list[str] = [] - fake_module = types.ModuleType("deepagents") - fake_module.create_deep_agent = lambda **kwargs: object() - monkeypatch.setitem(sys.modules, "deepagents", fake_module) + class _AsyncSaver: + async def adelete_thread(self, tid: str) -> None: + deleted.append(tid) - provider._build_agent() - assert provider._checkpointer is sentinel + provider = DeepagentProvider( + model="x", checkpointer=_AsyncSaver(), auto_checkpointer=False, auto_store=False + ) + provider._sessions["s1"] = "thread-A" + + async def _run(): + provider.truncate_session("s1", []) + + asyncio.run(_run()) + assert deleted == ["thread-A"] + + def test_adelete_exception_falls_through_to_dict_pop(self): + class _Saver: + def __init__(self): + self.storage = {"thread-A": "x"} + + async def adelete_thread(self, _tid: str) -> None: + raise RuntimeError("adelete failed") + + saver = _Saver() + provider = DeepagentProvider( + model="x", checkpointer=saver, auto_checkpointer=False, auto_store=False + ) + provider._sessions["s1"] = "thread-A" + provider.truncate_session("s1", []) + assert "thread-A" not in saver.storage + + +# ============================================================================= +# Middleware factories — only exercised when langchain is installed +# ============================================================================= + + +class TestBuildMiddlewares: + """The middleware factories return None when langchain is missing and + behave as cached singletons otherwise.""" + + def test_inline_tool_call_middleware_returns_singleton(self, monkeypatch): + monkeypatch.setattr(da, "_inline_tool_call_middleware_singleton", None) + first = da._build_inline_tool_call_middleware() + if first is not None: + second = da._build_inline_tool_call_middleware() + assert first is second + + def test_inline_tool_call_middleware_returns_none_when_langchain_missing(self, monkeypatch): + monkeypatch.setattr(da, "_inline_tool_call_middleware_singleton", None) + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "langchain.agents.middleware": + raise ImportError("no langchain agents") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + assert da._build_inline_tool_call_middleware() is None + + def test_plan_continuation_middleware_returns_singleton(self, monkeypatch): + monkeypatch.setattr(da, "_plan_middleware_singleton", None) + first = da._build_plan_continuation_middleware() + if first is not None: + second = da._build_plan_continuation_middleware() + assert first is second + + def test_plan_continuation_middleware_returns_none_when_langchain_missing(self, monkeypatch): + monkeypatch.setattr(da, "_plan_middleware_singleton", None) + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "langchain.agents.middleware": + raise ImportError("no langchain") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + assert da._build_plan_continuation_middleware() is None + + def test_plan_continuation_after_model_returns_nudge(self, monkeypatch): + monkeypatch.setattr(da, "_plan_middleware_singleton", None) + mw = da._build_plan_continuation_middleware() + if mw is None: + pytest.skip("langchain.agents.middleware not installed") + msg = AIMessage() + state = { + "messages": [msg], + "todos": [{"status": "pending", "content": "do-it"}], + } + result = mw.after_model(state, runtime=None) + assert result is not None + assert result["jump_to"] == "model" + assert "do-it" in result["messages"][0].content + + def test_plan_continuation_after_model_returns_none_when_no_pending(self, monkeypatch): + monkeypatch.setattr(da, "_plan_middleware_singleton", None) + mw = da._build_plan_continuation_middleware() + if mw is None: + pytest.skip("langchain.agents.middleware not installed") + assert mw.after_model({}, runtime=None) is None + + def test_inline_tool_call_middleware_wrap_model_call_sync(self, monkeypatch): + monkeypatch.setattr(da, "_inline_tool_call_middleware_singleton", None) + mw = da._build_inline_tool_call_middleware() + if mw is None: + pytest.skip("langchain.agents.middleware not installed") + + class _Resp: + def __init__(self, result): + self.result = result + + msg = AIMessage(content="plain") + + def handler(_request): + return _Resp([msg]) + + out = mw.wrap_model_call(request="ignored", handler=handler) + assert out.result == [msg] + + def test_inline_tool_call_middleware_awrap_model_call(self, monkeypatch): + monkeypatch.setattr(da, "_inline_tool_call_middleware_singleton", None) + mw = da._build_inline_tool_call_middleware() + if mw is None: + pytest.skip("langchain.agents.middleware not installed") + + class _Resp: + def __init__(self, result): + self.result = result + + msg = AIMessage(content="plain") + + async def handler(_request): + return _Resp([msg]) + + async def _run(): + return await mw.awrap_model_call(request="r", handler=handler) + + assert asyncio.run(_run()).result == [msg] + + +# ============================================================================= +# prompt() finally block — aclose error swallowed +# ============================================================================= + + +class TestPromptFinallyBlock: + async def test_aclose_exception_swallowed(self, provider_factory): + class _BadIter: + def __aiter__(self): + return self + + async def __anext__(self): + raise StopAsyncIteration + + async def aclose(self): + raise RuntimeError("aclose-failure") + + class _Agent: + def astream_events(self, _payload, config=None, version="v2"): + return _BadIter() + + provider = provider_factory(agent=_Agent()) + await provider.initialize(ClientCapabilities()) + sid = await provider.new_session("/tmp") + # The aclose error inside the finally block is logged and swallowed + async for _ in provider.prompt(sid, [TextPart(text="hi")]): + pass diff --git a/pywry/tests/test_e2e_deploy_mode.py b/pywry/tests/test_e2e_deploy_mode.py index 5596594..ceb91b8 100644 --- a/pywry/tests/test_e2e_deploy_mode.py +++ b/pywry/tests/test_e2e_deploy_mode.py @@ -33,13 +33,12 @@ InvalidStatusCode = None from pywry.config import clear_settings -from pywry.inline import HAS_FASTAPI, _start_server, _state, stop_server +from pywry.inline import _start_server, _state, stop_server + from pywry.state._factory import clear_state_caches -# Skip if FastAPI not installed pytestmark = [ - pytest.mark.skipif(not HAS_FASTAPI, reason="FastAPI not installed"), pytest.mark.redis, pytest.mark.container, ] diff --git a/pywry/tests/test_e2e_rbac_widgets.py b/pywry/tests/test_e2e_rbac_widgets.py index d693f4e..07a5a29 100644 --- a/pywry/tests/test_e2e_rbac_widgets.py +++ b/pywry/tests/test_e2e_rbac_widgets.py @@ -21,13 +21,12 @@ import pytest from pywry.config import clear_settings -from pywry.inline import HAS_FASTAPI, _start_server, _state, stop_server +from pywry.inline import _start_server, _state, stop_server + from pywry.state._factory import clear_state_caches -# Skip if FastAPI not installed pytestmark = [ - pytest.mark.skipif(not HAS_FASTAPI, reason="FastAPI not installed"), pytest.mark.redis, pytest.mark.container, ] diff --git a/pywry/tests/test_error_paths.py b/pywry/tests/test_error_paths.py deleted file mode 100644 index 884543c..0000000 --- a/pywry/tests/test_error_paths.py +++ /dev/null @@ -1,583 +0,0 @@ -"""Tests for error paths and edge cases. - -These tests verify proper error handling for: -- Invalid configuration values -- Malformed callback data -- Missing/invalid resources -- Boundary conditions -- Type validation errors - -Run with: pytest tests/test_error_paths.py -v -""" - -from __future__ import annotations - -import pytest - -from pydantic import ValidationError - - -# ============================================================================= -# Invalid Configuration Tests -# ============================================================================= - - -class TestInvalidWindowConfig: - """Tests for invalid WindowConfig values.""" - - def test_width_below_minimum_raises(self) -> None: - """Width below 200 raises validation error.""" - from pywry.models import WindowConfig - - with pytest.raises(ValidationError): - WindowConfig(width=100) # min is 200 - - def test_height_below_minimum_raises(self) -> None: - """Height below 150 raises validation error.""" - from pywry.models import WindowConfig - - with pytest.raises(ValidationError): - WindowConfig(height=100) # min is 150 - - def test_negative_width_raises(self) -> None: - """Negative width raises validation error.""" - from pywry.models import WindowConfig - - with pytest.raises(ValidationError): - WindowConfig(width=-100) - - def test_negative_height_raises(self) -> None: - """Negative height raises validation error.""" - from pywry.models import WindowConfig - - with pytest.raises(ValidationError): - WindowConfig(height=-100) - - def test_min_width_below_minimum_raises(self) -> None: - """min_width below 100 raises validation error.""" - from pywry.models import WindowConfig - - with pytest.raises(ValidationError): - WindowConfig(min_width=50) # min is 100 - - def test_min_height_below_minimum_raises(self) -> None: - """min_height below 100 raises validation error.""" - from pywry.models import WindowConfig - - with pytest.raises(ValidationError): - WindowConfig(min_height=50) # min is 100 - - def test_invalid_theme_string_raises(self) -> None: - """Invalid theme string raises validation error.""" - from pywry.models import WindowConfig - - with pytest.raises(ValidationError): - WindowConfig(theme="invalid_theme") # type: ignore - - def test_invalid_plotly_theme_string_raises(self) -> None: - """Invalid plotly_theme string raises validation error.""" - from pywry.models import WindowConfig - - with pytest.raises(ValidationError): - WindowConfig(plotly_theme="invalid_plotly_theme") # type: ignore - - def test_min_width_greater_than_width_raises(self) -> None: - """min_width > width raises validation error.""" - from pywry.models import WindowConfig - - with pytest.raises(ValidationError): - WindowConfig(width=300, min_width=500) - - def test_min_height_greater_than_height_raises(self) -> None: - """min_height > height raises validation error.""" - from pywry.models import WindowConfig - - with pytest.raises(ValidationError): - WindowConfig(height=200, min_height=400) - - -class TestInvalidSecuritySettings: - """Tests for invalid SecuritySettings values.""" - - def test_empty_default_src_is_allowed(self) -> None: - """Empty default_src is allowed (permissive).""" - from pywry.config import SecuritySettings - - # Empty string is technically valid, just insecure - settings = SecuritySettings(default_src="") - assert settings.default_src == "" - - def test_none_default_src_uses_default(self) -> None: - """None default_src falls back to default value.""" - from pywry.config import SecuritySettings - - settings = SecuritySettings() - assert settings.default_src != "" - - -class TestInvalidAssetSettings: - """Tests for invalid AssetSettings values.""" - - def test_invalid_plotly_version_format(self) -> None: - """Invalid plotly version format is accepted (no validation).""" - from pywry.config import AssetSettings - - # Version strings aren't validated, user's responsibility - settings = AssetSettings(plotly_version="not-a-version") - assert settings.plotly_version == "not-a-version" - - -# ============================================================================= -# Invalid Model Data Tests -# ============================================================================= - - -class TestInvalidHtmlContent: - """Tests for invalid HtmlContent values.""" - - def test_empty_html_allowed(self) -> None: - """Empty HTML is allowed.""" - from pywry.models import HtmlContent - - content = HtmlContent(html="") - assert content.html == "" - - def test_none_html_raises(self) -> None: - """None HTML raises validation error.""" - from pywry.models import HtmlContent - - with pytest.raises(ValidationError): - HtmlContent(html=None) # type: ignore - - def test_invalid_json_data_type_raises(self) -> None: - """Non-dict json_data raises validation error.""" - from pywry.models import HtmlContent - - with pytest.raises(ValidationError): - HtmlContent(html="
", json_data="not a dict") # type: ignore - - -class TestInvalidGridModels: - """Tests for Grid model validation.""" - - def test_coldef_field_can_be_none(self) -> None: - """None field is allowed (field is optional).""" - from pywry.grid import ColDef - - col = ColDef(field=None) - assert col.field is None - - def test_coldef_empty_field_allowed(self) -> None: - """Empty string field is allowed (user's responsibility).""" - from pywry.grid import ColDef - - col = ColDef(field="") - assert col.field == "" - - def test_coldef_with_valid_field(self) -> None: - """Valid field name works correctly.""" - from pywry.grid import ColDef - - col = ColDef(field="myColumn") - assert col.field == "myColumn" - - def test_coldef_negative_width_raises(self) -> None: - """Negative width raises validation error.""" - from pywry.grid import ColDef - - with pytest.raises(ValidationError): - ColDef(field="test", width=-100) - - def test_coldef_negative_min_width_raises(self) -> None: - """Negative min_width raises validation error.""" - from pywry.grid import ColDef - - with pytest.raises(ValidationError): - ColDef(field="test", min_width=-50) - - def test_coldef_negative_max_width_raises(self) -> None: - """Negative max_width raises validation error.""" - from pywry.grid import ColDef - - with pytest.raises(ValidationError): - ColDef(field="test", max_width=-50) - - def test_coldef_zero_width_allowed(self) -> None: - """Zero width is allowed (just not visible).""" - from pywry.grid import ColDef - - col = ColDef(field="test", width=0) - assert col.width == 0 - - -class TestInvalidToolbarModels: - """Tests for invalid Toolbar model values.""" - - def test_button_empty_label_uses_default(self) -> None: - """Empty button label uses 'Button' default in HTML.""" - from pywry.toolbar import Button - - btn = Button(label="", event="toolbar:click") - html = btn.build_html() - assert "Button" in html # Default label used - - def test_button_empty_event_raises(self) -> None: - """Empty event raises validation error.""" - from pywry.toolbar import Button - - with pytest.raises(ValidationError): - Button(label="Test", event="") - - def test_button_invalid_event_format_raises(self) -> None: - """Invalid event format raises validation error.""" - from pywry.toolbar import Button - - with pytest.raises(ValidationError): - Button(label="Test", event="no_namespace") - - def test_select_empty_options_allowed(self) -> None: - """Select without options is allowed.""" - from pywry.toolbar import Select - - sel = Select(event="view:change", options=[]) - assert len(sel.options) == 0 - - def test_toolbar_invalid_position_raises(self) -> None: - """Invalid toolbar position raises validation error.""" - from pywry.toolbar import Toolbar - - with pytest.raises(ValidationError): - Toolbar(position="invalid", items=[]) # type: ignore - - def test_text_input_negative_debounce_raises(self) -> None: - """Negative debounce raises validation error.""" - from pywry.toolbar import TextInput - - with pytest.raises(ValidationError): - TextInput(event="search:query", debounce=-1) - - def test_slider_min_greater_than_max_raises(self) -> None: - """Slider with min > max raises validation error.""" - from pywry.toolbar import SliderInput - - with pytest.raises(ValidationError): - SliderInput(event="zoom:level", min=100, max=0) - - def test_slider_value_out_of_range_raises(self) -> None: - """Slider with value outside min/max raises validation error.""" - from pywry.toolbar import SliderInput - - with pytest.raises(ValidationError): - SliderInput(event="zoom:level", value=150, min=0, max=100) - - def test_range_start_greater_than_end_raises(self) -> None: - """Range with start > end raises validation error.""" - from pywry.toolbar import RangeInput - - with pytest.raises(ValidationError): - RangeInput(event="filter:range", start=100, end=0) - - def test_range_min_greater_than_max_raises(self) -> None: - """Range with min > max raises validation error.""" - from pywry.toolbar import RangeInput - - with pytest.raises(ValidationError): - RangeInput(event="filter:range", min=100, max=0) - - def test_range_start_out_of_range_raises(self) -> None: - """Range with start outside min/max raises validation error.""" - from pywry.toolbar import RangeInput - - with pytest.raises(ValidationError): - RangeInput(event="filter:range", start=-50, min=0, max=100) - - -# ============================================================================= -# Callback/Event Error Tests -# ============================================================================= - - -class TestCallbackErrors: - """Tests for callback registry error handling.""" - - def test_dispatch_nonexistent_label_returns_false(self) -> None: - """Dispatching to nonexistent label returns False.""" - from pywry.callbacks import get_registry - - registry = get_registry() - registry.clear() - - result = registry.dispatch("nonexistent-label", "test:event", {}) - assert result is False - - def test_dispatch_nonexistent_event_returns_false(self) -> None: - """Dispatching nonexistent event returns False.""" - from pywry.callbacks import get_registry - - registry = get_registry() - registry.clear() - registry.register("test-window", "known:event", lambda _: None) - - result = registry.dispatch("test-window", "unknown:event", {}) - assert result is False - - def test_unregister_nonexistent_callback_returns_false(self) -> None: - """Unregistering nonexistent callback returns False.""" - from pywry.callbacks import get_registry - - registry = get_registry() - registry.clear() - - result = registry.unregister("nonexistent", "test:event", lambda _: None) - assert result is False - - def test_register_invalid_event_type_returns_false(self) -> None: - """Registering invalid event type returns False.""" - from pywry.callbacks import get_registry - - registry = get_registry() - registry.clear() - - result = registry.register("test-window", "invalid", lambda _: None) - assert result is False - - -class TestEventValidationErrors: - """Tests for event type validation errors.""" - - def test_validate_empty_event_type(self) -> None: - """Empty event type is invalid.""" - from pywry.models import validate_event_type - - assert validate_event_type("") is False - - def test_validate_no_namespace(self) -> None: - """Event without namespace is invalid.""" - from pywry.models import validate_event_type - - assert validate_event_type("click") is False - - def test_validate_empty_namespace(self) -> None: - """Empty namespace is invalid.""" - from pywry.models import validate_event_type - - assert validate_event_type(":click") is False - - def test_validate_empty_event_name(self) -> None: - """Empty event name is invalid.""" - from pywry.models import validate_event_type - - assert validate_event_type("toolbar:") is False - - -# ============================================================================= -# Exception Hierarchy Tests -# ============================================================================= - - -class TestExceptionChaining: - """Tests for exception chaining and context preservation.""" - - def test_window_error_preserves_context(self) -> None: - """WindowError preserves context in message.""" - from pywry.exceptions import WindowError - - exc = WindowError("Failed", label="test-win", operation="close") - exc_str = str(exc) - assert "test-win" in exc_str - assert "close" in exc_str - - def test_ipc_timeout_includes_timeout_value(self) -> None: - """IPCTimeoutError includes timeout value in message.""" - from pywry.exceptions import IPCTimeoutError - - exc = IPCTimeoutError("Timed out", timeout=5.0) - exc_str = str(exc) - assert "5.0" in exc_str - - def test_exception_can_be_caught_as_base(self) -> None: - """Specific exceptions can be caught as base PyWryException.""" - from pywry.exceptions import PyWryException, WindowError - - try: - raise WindowError("Test error", label="win") - except PyWryException as e: - assert "Test error" in str(e) - - -# ============================================================================= -# Boundary Condition Tests -# ============================================================================= - - -class TestBoundaryConditions: - """Tests for boundary conditions and edge cases.""" - - def test_very_large_window_dimensions(self) -> None: - """Very large window dimensions are accepted.""" - from pywry.models import WindowConfig - - config = WindowConfig(width=10000, height=10000) - assert config.width == 10000 - assert config.height == 10000 - - def test_minimum_valid_dimensions(self) -> None: - """Minimum valid dimensions are accepted (width=200, height=150).""" - from pywry.models import WindowConfig - - # Minimum allowed values based on model validation - config = WindowConfig(width=200, height=150, min_width=100, min_height=100) - assert config.width == 200 - assert config.height == 150 - - def test_very_long_title(self) -> None: - """Very long title is accepted.""" - from pywry.models import WindowConfig - - long_title = "A" * 1000 - config = WindowConfig(title=long_title) - assert config.title == long_title - - def test_unicode_in_title(self) -> None: - """Unicode characters in title are accepted.""" - from pywry.models import WindowConfig - - config = WindowConfig(title="测试窗口 🪟") - assert config.title == "测试窗口 🪟" - - def test_unicode_in_html_content(self) -> None: - """Unicode characters in HTML content are accepted.""" - from pywry.models import HtmlContent - - content = HtmlContent(html="
こんにちは 🌍
") - assert "こんにちは" in content.html - - def test_special_characters_in_json_data(self) -> None: - """Special characters in JSON data are preserved.""" - from pywry.models import HtmlContent - - content = HtmlContent( - html="
", - json_data={"message": "Hello "}, - ) - json_data = content.json_data - assert json_data is not None - assert ""}, + ) + json_data = content.json_data + assert json_data is not None + assert "" in result + + def test_uses_provided_loader(self, tmp_path): + """When a loader is provided, it's used in place of the default.""" + from pywry.asset_loader import AssetLoader + from pywry.templates import build_custom_scripts + + js_file = tmp_path / "x.js" + js_file.write_text("var x = 1;") + loader = AssetLoader(base_dir=tmp_path) + content = HtmlContent(html="
", script_files=["x.js"]) + result = build_custom_scripts(content, loader=loader) + assert "var x = 1;" in result + + +class TestBuildGlobalCssNoPath: + """Tests for build_global_css/build_global_scripts when settings.path is empty (488, 527).""" + + def test_global_css_without_path_uses_default_loader(self, tmp_path, monkeypatch): + """When settings.path is empty, default get_asset_loader is used (line 488).""" + from pywry import templates as tmpl + from pywry.asset_loader import AssetLoader + + css_file = tmp_path / "global.css" + css_file.write_text("/* default loader */") + default_loader = AssetLoader(base_dir=tmp_path) + monkeypatch.setattr( + "pywry.asset_loader.get_asset_loader", lambda: default_loader + ) + settings = AssetSettings(css_files=["global.css"]) + result = tmpl.build_global_css(settings) + assert "/* default loader */" in result + + def test_global_scripts_without_path_uses_default_loader(self, tmp_path, monkeypatch): + """When settings.path is empty, default get_asset_loader is used (line 527).""" + from pywry import templates as tmpl + from pywry.asset_loader import AssetLoader + + js_file = tmp_path / "global.js" + js_file.write_text("/* default scripts */") + default_loader = AssetLoader(base_dir=tmp_path) + monkeypatch.setattr( + "pywry.asset_loader.get_asset_loader", lambda: default_loader + ) + settings = AssetSettings(script_files=["global.js"]) + result = tmpl.build_global_scripts(settings) + assert "/* default scripts */" in result + + +class TestAddThemeClassToHtmlTag: + """Tests for _add_theme_class_to_html_tag (lines 600-616).""" + + def test_adds_class_when_no_class_attribute(self): + """Adds class attribute when none exists.""" + from pywry.templates import _add_theme_class_to_html_tag + + result = _add_theme_class_to_html_tag("", "pywry-theme-dark") + assert 'class="pywry-theme-dark"' in result + + def test_appends_to_existing_class(self): + """Appends new class to existing class attribute.""" + from pywry.templates import _add_theme_class_to_html_tag + + result = _add_theme_class_to_html_tag( + '', "pywry-theme-dark" + ) + assert 'class="existing pywry-theme-dark"' in result + + def test_does_not_duplicate_class(self): + """Does not add class if already present.""" + from pywry.templates import _add_theme_class_to_html_tag + + result = _add_theme_class_to_html_tag( + '', "pywry-theme-dark" + ) + # Class should appear exactly once. + assert result.count("pywry-theme-dark") == 1 + + def test_preserves_other_attributes(self): + """Preserves other attributes on the html tag.""" + from pywry.templates import _add_theme_class_to_html_tag + + result = _add_theme_class_to_html_tag( + '', "pywry-theme-light" + ) + assert 'lang="en"' in result + assert "pywry-theme-light" in result + + +class TestInjectModalBeforeBodyClose: + """Tests for _inject_modal_before_body_close (lines 621-626).""" + + def test_returns_html_unchanged_when_modal_empty(self): + """When modal_html is empty, html is returned unchanged.""" + from pywry.templates import _inject_modal_before_body_close + + html = "x" + assert _inject_modal_before_body_close(html, "") == html + + def test_injects_modal_before_body_close(self): + """Modal HTML is inserted just before .""" + from pywry.templates import _inject_modal_before_body_close + + result = _inject_modal_before_body_close( + "main", '' + ) + assert 'main' in result + + def test_no_body_close_returns_html_unchanged(self): + """When is missing, html is returned unchanged.""" + from pywry.templates import _inject_modal_before_body_close + + html = "orphan" + result = _inject_modal_before_body_close(html, "
") + assert result == html + + +class TestInjectIntoCompleteDoc: + """Tests for _inject_into_complete_doc (lines 631, 653-670, 809).""" + + def test_complete_doc_with_head_keeps_doctype(self): + """A complete document with has scripts injected before .""" + user_html = ( + "Mine" + "

Body

" + ) + config = WindowConfig() + content = HtmlContent(html=user_html) + result = build_html(content, config, window_label="main") + # Doctype preserved + assert result.startswith("") + # User title preserved + assert "Mine" in result + # Body content preserved + assert "

Body

" in result + + def test_complete_doc_no_head_gets_head_inserted(self): + """A complete document missing has one inserted (line 669).""" + user_html = ( + "

NoHead

" + ) + config = WindowConfig() + content = HtmlContent(html=user_html) + result = build_html(content, config, window_label="main") + # A head with the CSP meta tag should now exist. + assert "" in result + assert "Content-Security-Policy" in result + assert "

NoHead

" in result + + def test_complete_doc_doctype_only_returns_html_unchanged(self): + """A 'complete doc' starting with falls through (line 670).""" + from pywry.templates import _inject_into_complete_doc + + # Pathological input: with no and no . + user_html = "just text" + components = { + "csp_meta": "", + "base_styles": "", + "json_script": "", + "plotly_script": "", + "aggrid_script": "", + "tvchart_script": "", + "init_script": "", + "toolbar_script": "", + "modal_scripts": "", + "custom_css": "", + "custom_scripts": "", + "global_css": "", + "global_scripts": "", + "custom_init": "", + } + result = _inject_into_complete_doc( + user_html, "pywry-theme-dark", "", components + ) + # Without or , the function returns the (theme-class-untouched) input. + assert result == user_html + + def test_complete_doc_with_modals_injects_modal_before_body_close(self): + """Modals on a complete doc get injected before .""" + from pywry.modal import Modal + from pywry.toolbar import Button + + user_html = ( + "" + "

Main

" + ) + config = WindowConfig() + content = HtmlContent(html=user_html) + modal = Modal(title="X", items=[Button(label="OK", event="m:ok")]) + result = build_html(content, config, window_label="main", modals=[modal]) + # Modal markup must appear, and it must appear before . + body_close_idx = result.lower().rfind("") + modal_idx = result.find(modal.component_id) + assert modal_idx != -1 + assert modal_idx < body_close_idx + + +class TestChatHandlersInjection: + """Tests for chat handlers JS injection (lines 791-802).""" + + def test_chat_handlers_injected_when_pywry_chat_present(self, monkeypatch): + """When 'pywry-chat' appears in HTML, the chat handlers JS is appended.""" + from pywry import templates as tmpl + + monkeypatch.setattr(tmpl, "get_chat_handlers_js", lambda: "function initChatHandlers(){}") + config = WindowConfig() + content = HtmlContent(html='
') + result = tmpl.build_html(content, config, window_label="main") + assert "function initChatHandlers(){}" in result + assert "window.initChatHandlers = initChatHandlers" in result + assert "initChatHandlers(document,window.pywry)" in result + + def test_chat_handlers_skipped_when_no_chat_class(self, monkeypatch): + """Without 'pywry-chat' in HTML, chat handlers JS is not injected.""" + from pywry import templates as tmpl + + monkeypatch.setattr(tmpl, "get_chat_handlers_js", lambda: "SENTINEL_CHAT_JS") + config = WindowConfig() + content = HtmlContent(html="
no chat here
") + result = tmpl.build_html(content, config, window_label="main") + assert "SENTINEL_CHAT_JS" not in result + + def test_chat_handlers_skipped_when_no_js_bundled(self, monkeypatch): + """When pywry-chat present but get_chat_handlers_js empty, the chat init wrapper is not added.""" + from pywry import templates as tmpl + + monkeypatch.setattr(tmpl, "get_chat_handlers_js", lambda: "") + config = WindowConfig() + content = HtmlContent(html='
') + result = tmpl.build_html(content, config, window_label="main") + # The chat-init wrapper added by build_html (assigning to window.initChatHandlers) + # only appears when the JS is bundled — without it that exact wrapper is absent. + assert "window.initChatHandlers = initChatHandlers" not in result + + +class TestBuildContentUpdateScript: + """Tests for build_content_update_script (lines 828-829).""" + + def test_includes_escaped_html_payload(self): + """The function returns JS that injects the (JSON-escaped) html into the container.""" + from pywry.templates import build_content_update_script + + result = build_content_update_script("

hello

") + # JSON-escaped content (with literal backslashes for quotes) appears in the script. + assert '"

hello<\\/p>"' in result or '"

hello

"' in result + assert "pywry-container" in result + + def test_preserves_unicode_without_escaping(self): + """ensure_ascii=False means emoji/unicode are preserved literally.""" + from pywry.templates import build_content_update_script + + result = build_content_update_script("

café

") + assert "café" in result + + def test_initialises_toolbar_and_chat_handlers(self): + """Generated script calls initToolbarHandlers and initChatHandlers if present.""" + from pywry.templates import build_content_update_script + + result = build_content_update_script("

x

") + assert "initToolbarHandlers" in result + assert "initChatHandlers" in result diff --git a/pywry/tests/test_toolbar.py b/pywry/tests/test_toolbar.py index 2b9662f..266bcd3 100644 --- a/pywry/tests/test_toolbar.py +++ b/pywry/tests/test_toolbar.py @@ -5413,3 +5413,772 @@ def test_all_types_have_component_id(self) -> None: assert item.component_id.startswith(item.type), ( f"{item.type} component_id doesn't start with type" ) + + +# ============================================================================= +# Invalid Toolbar Model Tests +# ============================================================================= + + +class TestInvalidToolbarModels: + """Tests for invalid Toolbar model values.""" + + def test_button_empty_label_uses_default(self) -> None: + """Empty button label uses 'Button' default in HTML.""" + btn = Button(label="", event="toolbar:click") + html = btn.build_html() + assert "Button" in html # Default label used + + def test_select_empty_options_allowed(self) -> None: + """Select without options is allowed.""" + sel = Select(event="view:change", options=[]) + assert len(sel.options) == 0 + + def test_slider_min_greater_than_max_raises(self) -> None: + """Slider with min > max raises validation error.""" + with pytest.raises(ValidationError): + SliderInput(event="zoom:level", min=100, max=0) + + def test_slider_value_out_of_range_raises(self) -> None: + """Slider with value outside min/max raises validation error.""" + with pytest.raises(ValidationError): + SliderInput(event="zoom:level", value=150, min=0, max=100) + + def test_range_start_greater_than_end_raises(self) -> None: + """Range with start > end raises validation error.""" + with pytest.raises(ValidationError): + RangeInput(event="filter:range", start=100, end=0) + + def test_range_min_greater_than_max_raises(self) -> None: + """Range with min > max raises validation error.""" + with pytest.raises(ValidationError): + RangeInput(event="filter:range", min=100, max=0) + + def test_range_start_out_of_range_raises(self) -> None: + """Range with start outside min/max raises validation error.""" + with pytest.raises(ValidationError): + RangeInput(event="filter:range", start=-50, min=0, max=100) + + +# ============================================================================= +# Coverage backfill tests below: each covers a specific missing source line. +# ============================================================================= + + +class TestResolveScriptContent: + """Tests for _resolve_script_content (line 307).""" + + def test_existing_file_path_is_read(self, tmp_path) -> None: + """When script Path exists, its text content is returned (line 307).""" + from pywry.toolbar import _resolve_script_content + + script_file = tmp_path / "snippet.js" + script_file.write_text("console.log('hi');") + result = _resolve_script_content(script_file) + assert result == "console.log('hi');" + + def test_existing_filelike_string_is_read(self, tmp_path) -> None: + """A string that points at an existing file is read from disk.""" + from pywry.toolbar import _resolve_script_content + + script_file = tmp_path / "alert.js" + script_file.write_text("alert('hi');") + # Strings that don't look like JS keywords ('(', 'function', ...) get + # treated as paths. `alert(...)` starts with 'a' which is not in the + # JS-start set so it's treated as a path; the file exists so we read it. + result = _resolve_script_content(str(script_file)) + assert result == "alert('hi');" + + def test_inline_script_is_returned_verbatim(self) -> None: + """Inline JS-looking strings are passed through unchanged.""" + from pywry.toolbar import _resolve_script_content + + # Starts with 'function' so it's treated as inline JS. + result = _resolve_script_content("function init(){}") + assert result == "function init(){}" + + def test_empty_script_returns_none(self) -> None: + """Empty string returns None.""" + from pywry.toolbar import _resolve_script_content + + assert _resolve_script_content("") is None + assert _resolve_script_content(None) is None + + +class TestReservedNamespaceAllowedPatterns: + """Tests for reserved-namespace ALLOWED_RESERVED_PATTERNS (lines 396-397).""" + + def test_pywry_prefixed_event_allowed(self) -> None: + """Events matching an ALLOWED_RESERVED_PATTERN bypass the reserved-namespace block.""" + from pywry.toolbar import ( + ALLOWED_RESERVED_PATTERNS, + Button, + RESERVED_NAMESPACES, + ) + + # Find an event in the allowed patterns that uses a reserved namespace. + # ALLOWED_RESERVED_PATTERNS is a sequence of prefixes like "pywry:alert". + assert ALLOWED_RESERVED_PATTERNS, "expected at least one allowed pattern" + pattern = ALLOWED_RESERVED_PATTERNS[0] + # That pattern must use a reserved namespace, otherwise the test is moot. + assert pattern.split(":")[0] in RESERVED_NAMESPACES + # Construct a button whose event begins with that prefix. + btn = Button(label="Notify", event=pattern) + assert btn.event == pattern + + +class TestToolbarItemBaseRaisesNotImplemented: + """Tests for ToolbarItem.build_html (line 414).""" + + def test_base_build_html_raises(self) -> None: + from pywry.toolbar import ToolbarItem + + # Construct a minimal base instance. + item = ToolbarItem(event="ns:event") + with pytest.raises(NotImplementedError): + item.build_html() + + +class TestSelectOptionInvalidType: + """Tests for Select option normalization (line 528).""" + + def test_select_invalid_option_type_raises_typeerror(self) -> None: + from pywry.toolbar import Option, Select + + with pytest.raises((TypeError, ValidationError)): + Select(event="theme:change", options=[Option(label="A"), 42]) + + +class TestSelectSearchableSearchInput: + """Search-enabled Select embeds a SearchInput in its dropdown header (lines 555-556).""" + + def test_searchable_select_renders_search_header(self) -> None: + from pywry.toolbar import Option, Select + + sel = Select( + event="theme:change", + options=[Option(label="Dark", value="dark"), Option(label="Light", value="light")], + searchable=True, + ) + html_out = sel.build_html() + # The select-header wrapper from the searchable branch must be present. + assert "pywry-select-header" in html_out + assert "pywry-search-wrapper" in html_out + + +class TestMultiSelectEdgeCases: + """Tests for MultiSelect normalization and display text branches (lines 623, 627, 652).""" + + def test_multiselect_dict_option(self) -> None: + """A dict option is normalized into Option (line 623).""" + from pywry.toolbar import MultiSelect + + ms = MultiSelect( + event="filter:multi", + options=[{"label": "A", "value": "a"}, {"label": "B", "value": "b"}], + ) + assert len(ms.options) == 2 + assert ms.options[0].label == "A" + + def test_multiselect_string_option(self) -> None: + """A string option is normalized into Option(label=opt, value=opt).""" + from pywry.toolbar import MultiSelect + + ms = MultiSelect(event="filter:multi", options=["X", "Y"]) + assert ms.options[0].label == "X" + assert ms.options[0].value == "X" + + def test_multiselect_invalid_option_type_raises(self) -> None: + """Invalid option type raises TypeError (line 627).""" + from pywry.toolbar import MultiSelect + + with pytest.raises((TypeError, ValidationError)): + MultiSelect(event="filter:multi", options=[42]) + + def test_multiselect_many_selected_shows_count(self) -> None: + """More than 2 selected values renders 'N selected' (line 652).""" + from pywry.toolbar import MultiSelect + + ms = MultiSelect( + event="filter:multi", + options=["A", "B", "C", "D"], + selected=["A", "B", "C"], + ) + html_out = ms.build_html() + assert "3 selected" in html_out + + +class TestNumberInputStepAttribute: + """NumberInput with step attribute writes step="..." (line 1548).""" + + def test_step_appears_in_html(self) -> None: + from pywry.toolbar import NumberInput + + n = NumberInput(event="num:val", step=0.25) + assert 'step="0.25"' in n.build_html() + + +class TestSliderInputLabelBranch: + """SliderInput with a label wraps in pywry-input-group (line 1702).""" + + def test_label_wraps_input_group(self) -> None: + from pywry.toolbar import SliderInput + + s = SliderInput(label="Zoom", event="zoom:level") + html_out = s.build_html() + assert "pywry-input-group" in html_out + assert ">Zoom<" in html_out + + +class TestRangeInputLabelBranch: + """RangeInput with a label wraps in pywry-input-group (line 1850).""" + + def test_label_wraps_input_group(self) -> None: + from pywry.toolbar import RangeInput + + r = RangeInput(label="Range", event="filter:range") + html_out = r.build_html() + assert "pywry-input-group" in html_out + assert ">Range<" in html_out + + +class TestCheckboxStyleWrapper: + """Checkbox with style is wrapped in a span (line 1942).""" + + def test_style_wrapper(self) -> None: + from pywry.toolbar import Checkbox + + c = Checkbox(event="set:check", label="On", style="margin: 8px;") + html_out = c.build_html() + assert 'style="margin: 8px;"' in html_out + # Wrapper appears: outer + assert html_out.startswith(" None: + from pywry.toolbar import RadioGroup + + with pytest.raises((TypeError, ValidationError)): + RadioGroup(event="view:change", options=[42]) + + +class TestTabGroupInvalidOption: + """TabGroup invalid option type raises (line 2093).""" + + def test_invalid_option(self) -> None: + from pywry.toolbar import TabGroup + + with pytest.raises((TypeError, ValidationError)): + TabGroup(event="view:change", options=[42]) + + +class TestDivWithScript: + """Div with a script renders a " in html_out + + +class TestTickerItemHtmlContent: + """TickerItem.update_payload with html_content populates 'html' key (line 2386).""" + + def test_html_content_in_payload(self) -> None: + from pywry.toolbar import TickerItem + + item = TickerItem(ticker="AAPL", text="AAPL") + event, data = item.update_payload(html_content="AAPL") + assert event == "toolbar:marquee-set-item" + assert data["html"] == "AAPL" + + +class TestMarqueeBuildHtmlBranches: + """Marquee build_html branches: parent_id, static items, title_attr, separator, etc.""" + + def test_disabled_marquee_includes_disabled_class(self) -> None: + """A disabled Marquee adds the pywry-disabled class (line 2568).""" + from pywry.toolbar import Marquee + + m = Marquee(event="ns:marq", text="Hello", disabled=True) + html_out = m.build_html() + assert "pywry-disabled" in html_out + + def test_marquee_with_parent_id_attr(self) -> None: + """Passing parent_id adds data-parent-id (line 2602).""" + from pywry.toolbar import Marquee + + m = Marquee(event="ns:marq", text="Hello") + html_out = m.build_html(parent_id="parent-1") + assert 'data-parent-id="parent-1"' in html_out + + def test_marquee_static_with_items_emits_data_attrs(self) -> None: + """Static behavior with items renders data-items/data-speed (lines 2606-2608).""" + from pywry.toolbar import Marquee + + m = Marquee( + event="ns:marq", + text="", + behavior="static", + items=["A", "B", "C"], + speed=5.0, + ) + html_out = m.build_html() + assert "data-items=" in html_out + assert 'data-speed="5.0"' in html_out + + def test_marquee_with_description_emits_data_tooltip(self) -> None: + """description renders the title-attr branch (line 2612).""" + from pywry.toolbar import Marquee + + m = Marquee(event="ns:marq", text="Hello", description="Tip") + html_out = m.build_html() + assert "data-tooltip=" in html_out + assert "Tip" in html_out + + def test_marquee_static_single_track(self) -> None: + """Static behavior renders a single track span (line 2630).""" + from pywry.toolbar import Marquee + + m = Marquee(event="ns:marq", text="Hello", behavior="static") + html_out = m.build_html() + # Static path renders a single content span. + assert html_out.count('class="pywry-marquee-content"') == 1 + + +class TestMarqueeCollectScripts: + """Marquee.collect_scripts collects from nested Div/Marquee children (lines 2666-2671).""" + + def test_collects_div_child_scripts(self) -> None: + from pywry.toolbar import Div, Marquee + + m = Marquee( + event="ns:marq", + text="", + children=[Div(label="", event="ns:d", script="function a(){}")], + ) + scripts = m.collect_scripts() + assert any("function a()" in s for s in scripts) + + def test_collects_nested_marquee_scripts(self) -> None: + from pywry.toolbar import Div, Marquee + + inner = Marquee( + event="ns:inner", text="", + children=[Div(label="", event="ns:d", script="function b(){}")], + ) + outer = Marquee(event="ns:outer", text="", children=[inner]) + scripts = outer.collect_scripts() + assert any("function b()" in s for s in scripts) + + +class TestTickerItemClassRemove: + """TickerItem.update_payload with class_remove (line 2718).""" + + def test_class_remove(self) -> None: + from pywry.toolbar import TickerItem + + ti = TickerItem(ticker="AAPL", text="AAPL") + _, data = ti.update_payload(class_remove=["stock-down"]) + assert data["class_remove"] == ["stock-down"] + + +class TestToolbarInvalidItemTypes: + """Toolbar normalize_items raises for invalid items (line 2843).""" + + def test_invalid_item_type_raises(self) -> None: + from pywry.toolbar import Toolbar + + with pytest.raises((TypeError, ValidationError)): + Toolbar(items=[42]) + + +class TestToolbarBuildHtmlWithScript: + """Toolbar build_html with script tag (line 2912).""" + + def test_toolbar_with_inline_script_includes_script_tag(self) -> None: + from pywry.toolbar import Button, Toolbar + + tb = Toolbar( + items=[Button(label="Go", event="app:go")], + script="function init(){}", + ) + html_out = tb.build_html() + assert "" in html_out + + +class TestSecretInputBuildHtmlBranches: + """Specific SecretInput build_html paths (lines 911, 1124, 1168, 1241-1243).""" + + def test_secret_input_no_value_renders_empty(self) -> None: + """When value is empty, no mask is shown and value attribute is empty.""" + from pywry.toolbar import SecretInput + + si = SecretInput(event="auth:key") # no value + html_out = si.build_html() + # No data-has-value attribute when there is no value. + assert "data-has-value=" not in html_out + # Type is still password (masked). + assert "type='password'" in html_out or 'type="password"' in html_out + + def test_wrap_handler_for_get_falls_back_to_registry(self) -> None: + """When handler is None, _wrap_handler_for_get returns from registry (line 1168).""" + from pywry.toolbar import SecretInput + + si = SecretInput(event="auth:key", value="hunter2") + # Call _wrap_handler_for_get when no custom handler is configured. + # has_value is True so the registry is populated by register(). + si.register() + result = si._wrap_handler_for_get({}) + assert result == "hunter2" + + def test_get_secret_value_no_handler_no_value(self) -> None: + """get_secret_value with no handler and no stored value returns None (lines 1241-1243).""" + from pywry.toolbar import SecretInput + + si = SecretInput(event="auth:key") # No value, no handler. + assert si.get_secret_value() is None + + def test_get_secret_value_no_handler_uses_internal(self) -> None: + """get_secret_value with no handler returns the internal secret string.""" + from pywry.toolbar import SecretInput + + si = SecretInput(event="auth:key", value="hunter2") + assert si.get_secret_value() == "hunter2" + + +class TestSecretHandlerFactories: + """Tests for the secret reveal/copy/update handler factories (3078-3085, 3103-3107).""" + + def test_copy_handler_dispatches_response(self) -> None: + """make_copy_handler dispatches a copy-response with encoded value (3078-3085).""" + from pywry.toolbar import ( + SecretInput, + create_default_secret_handlers, + ) + + dispatched: list[tuple[str, dict]] = [] + + def dispatch(event_type: str, data: dict) -> None: + dispatched.append((event_type, data)) + + factories = create_default_secret_handlers(dispatch) + copy_handler = factories["copy"]("auth:key") + + # Set up a SecretInput so the registry has a value to copy. + si = SecretInput(event="auth:key", value="secret-value") + si.register() + + copy_handler({"componentId": si.component_id}, "auth:key:copy", "") + assert len(dispatched) == 1 + evt, data = dispatched[0] + assert evt == "auth:key:copy-response" + assert data["componentId"] == si.component_id + # Value is base64-encoded; decode and check it round-trips. + import base64 + + decoded = base64.b64decode(data["value"]).decode("utf-8") + assert decoded == "secret-value" + assert data["encoded"] is True + + def test_update_handler_updates_secret(self) -> None: + """make_update_handler decodes encoded values and updates the SecretInput (3103-3107).""" + from pywry.toolbar import ( + SecretInput, + create_default_secret_handlers, + encode_secret, + ) + + factories = create_default_secret_handlers(lambda *_: None) + si = SecretInput(event="auth:key", value="old") + handler = factories["update"](si) + + encoded = encode_secret("new-value") + handler( + {"value": encoded, "encoded": True, "componentId": si.component_id}, + "auth:key", + "", + ) + # The internal value should now reflect the decoded plaintext. + assert si.get_secret_value() == "new-value" + + def test_update_handler_ignores_payload_without_value(self) -> None: + """update handler returns early when data has no value (line 3103-3104).""" + from pywry.toolbar import SecretInput, create_default_secret_handlers + + factories = create_default_secret_handlers(lambda *_: None) + si = SecretInput(event="auth:key", value="keep") + handler = factories["update"](si) + handler({}, "auth:key", "") + # Internal value unchanged. + assert si.get_secret_value() == "keep" + + +class TestGetToolbarHandlersJsMissing: + """RuntimeError when toolbar-handlers.js cannot be loaded (line 3237).""" + + def test_missing_handlers_js_raises(self, monkeypatch, tmp_path) -> None: + from pywry import toolbar as toolbar_module + + monkeypatch.setattr(toolbar_module, "_SRC_DIR", tmp_path) + # Bust the lru_cache from earlier successful loads. + toolbar_module._get_toolbar_handlers_js.cache_clear() + toolbar_module._get_toolbar_script_content.cache_clear() + try: + with pytest.raises(RuntimeError, match="Toolbar handlers JS not found"): + toolbar_module._get_toolbar_handlers_js() + finally: + toolbar_module._get_toolbar_handlers_js.cache_clear() + toolbar_module._get_toolbar_script_content.cache_clear() + + +class TestGetToolbarScriptWithoutTag: + """get_toolbar_script(with_script_tag=False) returns raw JS (line 3333).""" + + def test_raw_js_returned(self) -> None: + from pywry.toolbar import get_toolbar_script + + raw = get_toolbar_script(with_script_tag=False) + # Returned value is raw script content (no ") + + +class TestWrapContentWithToolbarsBranches: + """wrap_content_with_toolbars: dict-like callable toolbar config (3394-3395, 3402).""" + + def test_dict_with_position_and_items_produces_wrapper(self) -> None: + from pywry.toolbar import wrap_content_with_toolbars + + wrapped = wrap_content_with_toolbars( + "
main
", + toolbars=[{"position": "top", "items": [{"type": "button", "label": "X", "event": "ns:x"}]}], + ) + assert "pywry-wrapper-top" in wrapped + assert "main" in wrapped + + def test_dict_without_items_is_skipped(self) -> None: + """A dict with no items produces no html and is skipped (line 3402).""" + from pywry.toolbar import wrap_content_with_toolbars + + wrapped = wrap_content_with_toolbars( + "
main
", + toolbars=[{"position": "top", "items": []}], + ) + # No top wrapper because items were empty. + assert "pywry-wrapper-top" not in wrapped + assert "main" in wrapped + + def test_object_with_build_html_used(self) -> None: + """Toolbar-like object exposing build_html and position is used (lines 3393-3395).""" + from pywry.toolbar import wrap_content_with_toolbars + + class FakeToolbar: + position = "top" + + def build_html(self) -> str: + return "
X
" + + wrapped = wrap_content_with_toolbars( + "
main
", toolbars=[FakeToolbar()] + ) + assert "custom-toolbar" in wrapped + # Top wrapper applied. + assert "pywry-wrapper-top" in wrapped + + +class TestRegisterSecretHandlersForToolbar: + """register_secret_handlers_for_toolbar wires reveal/copy/update events end-to-end.""" + + def test_registers_all_three_events_per_secret(self) -> None: + from pywry.toolbar import ( + SecretInput, + Toolbar, + register_secret_handlers_for_toolbar, + ) + + si = SecretInput(event="auth:key", value="secret-value") + tb = Toolbar(items=[si]) + + registered_events: list[str] = [] + + def on_func(event_type: str, _handler) -> bool: + registered_events.append(event_type) + return True + + dispatched: list[tuple[str, dict]] = [] + + def dispatch_func(event_type: str, data: dict) -> None: + dispatched.append((event_type, data)) + + result = register_secret_handlers_for_toolbar(tb, on_func, dispatch_func) + + # The base, reveal, and copy events should all be registered. + assert "auth:key" in registered_events + assert "auth:key:reveal" in registered_events + assert "auth:key:copy" in registered_events + assert set(result) == {"auth:key", "auth:key:reveal", "auth:key:copy"} + + +class TestLabelBranchesOnSimpleInputs: + """Cover the `if self.label` branches for inputs that lacked label coverage.""" + + def test_multiselect_with_label_wraps(self) -> None: + """MultiSelect with a label is wrapped in pywry-input-group (line 702).""" + from pywry.toolbar import MultiSelect + + ms = MultiSelect(label="Filter:", event="filter:multi", options=["a", "b"]) + html_out = ms.build_html() + assert "pywry-input-group" in html_out + assert ">Filter:<" in html_out + + def test_textinput_with_label_wraps(self) -> None: + """TextInput with a label uses the labelled wrapper (line 780).""" + from pywry.toolbar import TextInput + + ti = TextInput(label="Name:", event="form:name") + html_out = ti.build_html() + assert "pywry-input-group" in html_out + assert ">Name:<" in html_out + + def test_numberinput_with_label_wraps(self) -> None: + """NumberInput with a label uses the labelled wrapper (line 1569).""" + from pywry.toolbar import NumberInput + + ni = NumberInput(label="Qty:", event="form:qty") + html_out = ni.build_html() + assert "pywry-input-group" in html_out + assert ">Qty:<" in html_out + + def test_dateinput_with_min_and_max(self) -> None: + """DateInput with min/max emits both attributes (lines 1617, 1619).""" + from pywry.toolbar import DateInput + + di = DateInput(event="form:date", min="2024-01-01", max="2024-12-31") + html_out = di.build_html() + assert 'min="2024-01-01"' in html_out + assert 'max="2024-12-31"' in html_out + + def test_dateinput_with_label_wraps(self) -> None: + """DateInput with a label uses the labelled wrapper (line 1624).""" + from pywry.toolbar import DateInput + + di = DateInput(label="Start:", event="form:date") + html_out = di.build_html() + assert "pywry-input-group" in html_out + assert ">Start:<" in html_out + + def test_tabgroup_with_label_wraps(self) -> None: + """TabGroup with a label uses the labelled wrapper (line 2139).""" + from pywry.toolbar import TabGroup + + tg = TabGroup(label="View:", event="view:change", options=["A", "B"]) + html_out = tg.build_html() + assert "pywry-input-group" in html_out + assert ">View:<" in html_out + + +class TestRangeInputEndOutOfRange: + """RangeInput with end outside [min,max] raises (line 1769).""" + + def test_end_out_of_range_raises(self) -> None: + from pywry.toolbar import RangeInput + + with pytest.raises(ValidationError): + RangeInput(event="filter:range", start=0, end=200, min=0, max=100) + + +class TestTextAreaWithStyle: + """TextArea with a style merges it into inline style (line 1332).""" + + def test_textarea_style_in_inline_style(self) -> None: + from pywry.toolbar import TextArea + + ta = TextArea(event="form:notes", style="border: 1px solid red") + html_out = ta.build_html() + # The style attribute is HTML-escaped in the rendered output. + assert "border: 1px solid red" in html_out + + +class TestSecretInputCoerceValueAlreadySecret: + """SecretInput accepts an existing SecretStr unchanged (line 911).""" + + def test_already_secret_returns_self(self) -> None: + from pydantic import SecretStr + + from pywry.toolbar import SecretInput + + s = SecretStr("hunter2") + si = SecretInput(event="auth:key", value=s) + # The internal value should be the same SecretStr instance. + assert si.value is s + + +class TestMarqueeEmptyChildrenAndNestedMarquee: + """Marquee._build_children_html branches (lines 2516, 2526).""" + + def test_marquee_build_children_html_with_no_children(self) -> None: + """Direct call to _build_children_html with no children returns '' (line 2516). + + build_html() guards this defensively by only calling _build_children_html + when children is truthy, so we exercise the guard directly. + """ + from pywry.toolbar import Marquee + + m = Marquee(event="ns:marq", text="Hello") + assert m._build_children_html() == "" + + def test_marquee_with_empty_children_uses_text(self) -> None: + """When children is an empty list, the text path is used.""" + from pywry.toolbar import Marquee + + m = Marquee(event="ns:marq", text="Hello", children=[]) + # text gets rendered. + assert "Hello" in m.build_html() + + def test_marquee_with_nested_marquee_child(self) -> None: + """Nested Marquee child receives parent_id (line 2526).""" + from pywry.toolbar import Marquee + + inner = Marquee(event="ns:inner", text="Inner") + outer = Marquee(event="ns:outer", text="", children=[inner]) + html_out = outer.build_html() + assert "Inner" in html_out + # The inner marquee gets a data-parent-id pointing at the outer. + assert f'data-parent-id="{outer.component_id}"' in html_out + + +class TestMarqueeWithStyle: + """Marquee with style appends to inline style parts (line 2592).""" + + def test_style_in_inline_style(self) -> None: + from pywry.toolbar import Marquee + + m = Marquee(event="ns:marq", text="x", style="color: red") + html_out = m.build_html() + # The style attribute is composed as `--pywry-marquee-...; color: red`. + assert "color: red" in html_out + + +class TestMarqueeUpdatePayloadHtmlContent: + """Marquee.update_payload with html_content (line 2718).""" + + def test_html_content_in_payload(self) -> None: + from pywry.toolbar import Marquee + + m = Marquee(event="ns:marq", text="orig") + evt, data = m.update_payload(html_content="bold") + assert evt == "toolbar:marquee-set-content" + assert data["html"] == "bold" diff --git a/pywry/tests/test_tvchart.py b/pywry/tests/test_tvchart.py deleted file mode 100644 index d518bd9..0000000 --- a/pywry/tests/test_tvchart.py +++ /dev/null @@ -1,3050 +0,0 @@ -"""Tests for TradingView Lightweight Charts integration. - -Tests: -- TVChartConfig and sub-model serialization (snake_case to camelCase) -- normalize_ohlcv() for list-of-dicts, dict-of-lists, edge cases -- TVChartData properties (bars, volume, series_ids, total_rows) -- _serialize_timestamp() and _serialize_ohlcv_value() helpers -- _resolve_ohlcv_columns() alias resolution -- build_tvchart_toolbars() factory validation -- TVChartStateMixin event emission -- PyWryTVChartWidget class shape -- show_tvchart function signature and basic wiring -- Public API imports from pywry.__init__ -""" - -from __future__ import annotations - -import json - -from datetime import datetime, timezone -from typing import Any - -import pytest - -from pywry.tvchart import ( - TVChartBar, - TVChartData, - TVChartDatafeedBarUpdate, - TVChartDatafeedConfigRequest, - TVChartDatafeedConfigResponse, - TVChartDatafeedConfiguration, - TVChartDatafeedHistoryRequest, - TVChartDatafeedHistoryResponse, - TVChartDatafeedMarksRequest, - TVChartDatafeedMarksResponse, - TVChartDatafeedResolveRequest, - TVChartDatafeedResolveResponse, - TVChartDatafeedSearchRequest, - TVChartDatafeedSearchResponse, - TVChartDatafeedServerTimeRequest, - TVChartDatafeedServerTimeResponse, - TVChartDatafeedSubscribeRequest, - TVChartDatafeedSymbolType, - TVChartDatafeedTimescaleMarksRequest, - TVChartDatafeedTimescaleMarksResponse, - TVChartDatafeedUnsubscribeRequest, - TVChartExchange, - TVChartLibrarySubsessionInfo, - TVChartMark, - TVChartSearchSymbolResultItem, - TVChartSeriesData, - TVChartStateMixin, - TVChartSymbolInfo, - TVChartSymbolInfoPriceSource, - TVChartTimescaleMark, - build_tvchart_toolbars, - normalize_ohlcv, -) -from pywry.tvchart.config import ( - CrosshairConfig, - CrosshairMode, - LayoutConfig, - PriceScaleConfig, - PriceScaleMode, - SeriesConfig, - SeriesType, - TimeScaleConfig, - TVChartConfig, - WatermarkConfig, -) -from pywry.tvchart.normalize import ( - _resolve_ohlcv_columns, - _serialize_bar, - _serialize_ohlcv_value, - _serialize_series_from_rows, - _serialize_timestamp, -) - - -# ============================================================================= -# Fixtures -# ============================================================================= - - -def _make_ohlcv_rows(n: int = 5) -> list[dict[str, Any]]: - """Generate synthetic OHLCV rows with Unix timestamps.""" - base_time = 1_700_000_000 - return [ - { - "time": base_time + i * 86400, - "open": 100.0 + i, - "high": 105.0 + i, - "low": 98.0 + i, - "close": 103.0 + i, - "volume": 1_000_000 + i * 10_000, - } - for i in range(n) - ] - - -# ============================================================================= -# TVChartConfig serialization tests -# ============================================================================= - - -class TestTVChartConfigSerialization: - """Verify Pydantic models serialize snake_case to camelCase.""" - - def test_series_type_enum_values(self): - assert SeriesType.CANDLESTICK.value == "Candlestick" - assert SeriesType.LINE.value == "Line" - assert SeriesType.AREA.value == "Area" - assert SeriesType.HISTOGRAM.value == "Histogram" - - def test_series_config_camel_case(self): - cfg = SeriesConfig( - series_type=SeriesType.LINE, - price_scale_id="left", - up_color="#00ff00", - down_color="#ff0000", - ) - d = cfg.model_dump(by_alias=True, exclude_none=True) - assert "seriesType" in d - assert d["seriesType"] == "Line" - assert "priceScaleId" in d - assert d["priceScaleId"] == "left" - assert "upColor" in d - assert "downColor" in d - # No snake_case keys - assert "series_type" not in d - assert "price_scale_id" not in d - - def test_price_scale_config_defaults(self): - cfg = PriceScaleConfig() - d = cfg.model_dump(by_alias=True, exclude_none=True) - assert d["autoScale"] is True - assert d["mode"] == PriceScaleMode.NORMAL.value - - def test_time_scale_config_defaults(self): - cfg = TimeScaleConfig() - d = cfg.model_dump(by_alias=True, exclude_none=True) - assert d["rightOffset"] == 5 - assert d["barSpacing"] == 6.0 - assert d["timeVisible"] is True - - def test_crosshair_config_defaults(self): - cfg = CrosshairConfig() - d = cfg.model_dump(by_alias=True, exclude_none=True) - assert d["mode"] == CrosshairMode.MAGNET.value - - def test_layout_config_camel_case(self): - cfg = LayoutConfig(text_color="#fff", font_size=14, font_family="Arial") - d = cfg.model_dump(by_alias=True, exclude_none=True) - assert "textColor" in d - assert "fontSize" in d - assert "fontFamily" in d - assert "text_color" not in d - - def test_watermark_config_camel_case(self): - cfg = WatermarkConfig(visible=True, text="AAPL", font_size=64) - d = cfg.model_dump(by_alias=True, exclude_none=True) - assert d["visible"] is True - assert d["fontSize"] == 64 - assert "horzAlign" in d - assert "vertAlign" in d - - def test_tvchart_config_nested(self): - cfg = TVChartConfig( - time_scale=TimeScaleConfig(right_offset=10), - crosshair=CrosshairConfig(mode=CrosshairMode.NORMAL), - ) - d = cfg.model_dump(by_alias=True, exclude_none=True) - assert "timeScale" in d - assert d["timeScale"]["rightOffset"] == 10 - assert "crosshair" in d - assert d["crosshair"]["mode"] == CrosshairMode.NORMAL.value - - def test_tvchart_config_json_roundtrip(self): - cfg = TVChartConfig( - time_scale=TimeScaleConfig(right_offset=10), - layout=LayoutConfig(text_color="#ccc"), - ) - json_str = cfg.model_dump_json(by_alias=True, exclude_none=True) - parsed = json.loads(json_str) - assert "timeScale" in parsed - assert parsed["layout"]["textColor"] == "#ccc" - - -# ============================================================================= -# Timestamp serialization tests -# ============================================================================= - - -class TestSerializeTimestamp: - """Test the _serialize_timestamp helper function.""" - - def test_int(self): - assert _serialize_timestamp(1_700_000_000) == 1_700_000_000 - - def test_float(self): - assert _serialize_timestamp(1_700_000_000.5) == 1_700_000_000 - - def test_nan_returns_none(self): - assert _serialize_timestamp(float("nan")) is None - - def test_inf_returns_none(self): - assert _serialize_timestamp(float("inf")) is None - - def test_none_returns_none(self): - assert _serialize_timestamp(None) is None - - def test_datetime_utc(self): - dt = datetime(2023, 11, 15, 0, 0, 0, tzinfo=timezone.utc) - result = _serialize_timestamp(dt) - assert result == int(dt.timestamp()) - - def test_datetime_naive(self): - dt = datetime(2023, 11, 15, 0, 0, 0) - result = _serialize_timestamp(dt) - assert result == int(dt.replace(tzinfo=timezone.utc).timestamp()) - - def test_iso_string(self): - result = _serialize_timestamp("2023-11-15T00:00:00") - expected = int(datetime(2023, 11, 15, tzinfo=timezone.utc).timestamp()) - assert result == expected - - -# ============================================================================= -# OHLCV value serialization tests -# ============================================================================= - - -class TestSerializeOHLCVValue: - """Test the _serialize_ohlcv_value helper function.""" - - def test_float(self): - assert _serialize_ohlcv_value(100.5) == 100.5 - - def test_int(self): - assert _serialize_ohlcv_value(100) == 100.0 - - def test_nan_returns_none(self): - assert _serialize_ohlcv_value(float("nan")) is None - - def test_inf_returns_none(self): - assert _serialize_ohlcv_value(float("inf")) is None - - def test_none_returns_none(self): - assert _serialize_ohlcv_value(None) is None - - def test_string_number(self): - assert _serialize_ohlcv_value("42.5") == 42.5 - - def test_invalid_string_returns_none(self): - assert _serialize_ohlcv_value("not_a_number") is None - - -# ============================================================================= -# Column alias resolution tests -# ============================================================================= - - -class TestResolveOHLCVColumns: - """Test column name alias resolution.""" - - def test_standard_lowercase(self): - cols = ["time", "open", "high", "low", "close", "volume"] - result = _resolve_ohlcv_columns(cols) - assert result["time"] == "time" - assert result["open"] == "open" - assert result["high"] == "high" - assert result["low"] == "low" - assert result["close"] == "close" - assert result["volume"] == "volume" - - def test_capitalized(self): - cols = ["Date", "Open", "High", "Low", "Close", "Volume"] - result = _resolve_ohlcv_columns(cols) - assert result["time"] == "Date" - assert result["open"] == "Open" - assert result["close"] == "Close" - assert result["volume"] == "Volume" - - def test_single_letter(self): - cols = ["t", "o", "h", "l", "c", "v"] - result = _resolve_ohlcv_columns(cols) - assert result["time"] == "t" - assert result["open"] == "o" - assert result["close"] == "c" - - def test_adj_close_alias(self): - cols = ["time", "adj_close"] - result = _resolve_ohlcv_columns(cols) - assert result["close"] == "adj_close" - - def test_missing_columns_are_none(self): - cols = ["time", "close"] - result = _resolve_ohlcv_columns(cols) - assert result["open"] is None - assert result["high"] is None - assert result["low"] is None - assert result["volume"] is None - - def test_timestamp_alias(self): - cols = ["Timestamp", "close"] - result = _resolve_ohlcv_columns(cols) - assert result["time"] == "Timestamp" - - -# ============================================================================= -# _serialize_bar tests -# ============================================================================= - - -class TestSerializeBar: - """Test single-row bar serialization.""" - - def test_full_ohlcv_bar(self): - row = { - "time": 1700000000, - "open": 100, - "high": 105, - "low": 98, - "close": 103, - "volume": 1000000, - } - ohlcv_map = { - "time": "time", - "open": "open", - "high": "high", - "low": "low", - "close": "close", - "volume": "volume", - } - bar, vol = _serialize_bar(row, ohlcv_map) - assert bar is not None - assert bar["time"] == 1700000000 - assert bar["open"] == 100.0 - assert bar["high"] == 105.0 - assert bar["low"] == 98.0 - assert bar["close"] == 103.0 - assert vol is not None - assert vol["time"] == 1700000000 - assert vol["value"] == 1000000.0 - - def test_line_bar_close_only(self): - row = {"time": 1700000000, "close": 103} - ohlcv_map = { - "time": "time", - "open": None, - "high": None, - "low": None, - "close": "close", - "volume": None, - } - bar, vol = _serialize_bar(row, ohlcv_map) - assert bar is not None - assert "value" in bar - assert bar["value"] == 103.0 - assert "open" not in bar - assert vol is None - - def test_missing_time_returns_none(self): - row = {"close": 103} - ohlcv_map = { - "time": "time", - "open": None, - "high": None, - "low": None, - "close": "close", - "volume": None, - } - bar, vol = _serialize_bar(row, ohlcv_map) - assert bar is None - assert vol is None - - -# ============================================================================= -# normalize_ohlcv tests -# ============================================================================= - - -class TestNormalizeOhlcv: - """Test the main normalization entry point.""" - - def test_list_of_dicts_ohlcv(self): - rows = _make_ohlcv_rows(5) - result = normalize_ohlcv(rows) - assert isinstance(result, TVChartData) - assert len(result.series) == 1 - assert result.series[0].series_id == "main" - assert len(result.bars) == 5 - assert result.series[0].has_volume is True - assert len(result.volume) == 5 - assert result.series[0].series_type == SeriesType.CANDLESTICK - - def test_list_of_dicts_line_data(self): - rows = [{"time": 1700000000 + i * 86400, "close": 100.0 + i} for i in range(3)] - result = normalize_ohlcv(rows) - assert len(result.bars) == 3 - assert "value" in result.bars[0] - assert result.series[0].series_type == SeriesType.LINE - - def test_dict_of_lists(self): - data = { - "time": [1700000000, 1700086400], - "close": [100.0, 101.0], - } - result = normalize_ohlcv(data) - assert len(result.bars) == 2 - - def test_empty_list(self): - result = normalize_ohlcv([]) - assert len(result.series) == 1 - assert result.series[0].series_id == "main" - assert len(result.bars) == 0 - - def test_passthrough_tvchart_data(self): - """normalize_ohlcv returns TVChartData unchanged.""" - original = TVChartData( - series=[ - TVChartSeriesData(series_id="test", bars=[{"time": 1, "value": 2}], total_rows=1) - ], - ) - result = normalize_ohlcv(original) - assert result is original - - def test_raises_on_unsupported_type(self): - with pytest.raises(TypeError, match="Unsupported data type"): - normalize_ohlcv("not_valid_data") - - def test_raises_on_missing_time_column(self): - rows = [{"price": 100}] - with pytest.raises(ValueError, match="Could not resolve time column"): - normalize_ohlcv(rows) - - def test_raises_on_missing_close_column(self): - rows = [{"time": 1700000000, "foo": 100}] - with pytest.raises(ValueError, match="Could not resolve close/value column"): - normalize_ohlcv(rows) - - def test_max_bars_truncation(self): - rows = _make_ohlcv_rows(20) - result = normalize_ohlcv(rows, max_bars=5) - assert len(result.bars) == 5 - assert result.series[0].truncated_rows == 15 - assert result.series[0].total_rows == 20 - - def test_capitalized_columns(self): - rows = [ - {"Date": 1700000000, "Open": 100, "High": 105, "Low": 98, "Close": 103, "Volume": 1000} - ] - result = normalize_ohlcv(rows) - assert len(result.bars) == 1 - assert result.bars[0]["open"] == 100.0 - - -# ============================================================================= -# TVChartData property tests -# ============================================================================= - - -class TestTVChartDataProperties: - """Test TVChartData computed properties.""" - - def test_bars_property(self): - s = TVChartSeriesData(series_id="main", bars=[{"time": 1, "value": 2}], total_rows=1) - data = TVChartData(series=[s]) - assert data.bars == [{"time": 1, "value": 2}] - - def test_volume_property(self): - s = TVChartSeriesData( - series_id="main", - bars=[{"time": 1, "value": 2}], - volume=[{"time": 1, "value": 1000}], - total_rows=1, - ) - data = TVChartData(series=[s]) - assert data.volume == [{"time": 1, "value": 1000}] - - def test_series_ids_property(self): - data = TVChartData( - series=[ - TVChartSeriesData(series_id="AAPL", bars=[], total_rows=0), - TVChartSeriesData(series_id="MSFT", bars=[], total_rows=0), - ] - ) - assert data.series_ids == ["AAPL", "MSFT"] - - def test_total_rows_property(self): - data = TVChartData( - series=[ - TVChartSeriesData(series_id="a", bars=[], total_rows=100), - TVChartSeriesData(series_id="b", bars=[], total_rows=200), - ] - ) - assert data.total_rows == 300 - - def test_empty_series_properties(self): - data = TVChartData(series=[]) - assert data.bars == [] - assert data.volume == [] - assert data.series_ids == [] - assert data.total_rows == 0 - - -class TestTVChartDatafeedModels: - """Verify datafeed protocol model shape and serialization.""" - - # --- TVChartSymbolInfo (full LibrarySymbolInfo) --- - - def test_symbol_info_required_fields(self): - info = TVChartSymbolInfo() - dumped = info.model_dump(exclude_none=True) - assert dumped["name"] == "" - assert dumped["description"] == "" - assert dumped["exchange"] == "" - assert dumped["listed_exchange"] == "" - assert dumped["type"] == "stock" - assert dumped["session"] == "24x7" - assert dumped["timezone"] == "Etc/UTC" - assert dumped["minmov"] == 1 - assert dumped["pricescale"] == 100 - assert dumped["format"] == "price" - - def test_symbol_info_full_fields(self): - info = TVChartSymbolInfo( - name="AAPL", - symbol="NASDAQ:AAPL", - ticker="AAPL", - full_name="Apple Inc.", - description="Apple common stock", - exchange="NASDAQ", - listed_exchange="NASDAQ", - type="stock", - session="0930-1600", - timezone="America/New_York", - currency_code="USD", - minmov=1, - pricescale=100, - format="price", - has_intraday=True, - has_daily=True, - has_weekly_and_monthly=True, - supported_resolutions=["1", "5", "15", "60", "1D", "1W", "1M"], - intraday_multipliers=["1", "5", "15", "60"], - daily_multipliers=["1"], - weekly_multipliers=["1"], - monthly_multipliers=["1"], - visible_plots_set="ohlcv", - volume_precision=0, - data_status="streaming", - sector="Technology", - industry="Consumer Electronics", - logo_urls=["https://example.com/aapl.svg"], - ) - dumped = info.model_dump(exclude_none=True) - assert dumped["name"] == "AAPL" - assert dumped["ticker"] == "AAPL" - assert dumped["full_name"] == "Apple Inc." - assert dumped["has_intraday"] is True - assert dumped["has_daily"] is True - assert dumped["visible_plots_set"] == "ohlcv" - assert dumped["sector"] == "Technology" - assert dumped["supported_resolutions"] == ["1", "5", "15", "60", "1D", "1W", "1M"] - - def test_symbol_info_optional_fields_excluded(self): - info = TVChartSymbolInfo( - name="X", description="Test", exchange="NYSE", listed_exchange="NYSE" - ) - dumped = info.model_dump(exclude_none=True) - assert "ticker" not in dumped - assert "has_seconds" not in dumped - assert "expired" not in dumped - assert "subsessions" not in dumped - - def test_symbol_info_fractional_format(self): - info = TVChartSymbolInfo( - name="ZBM2023", - description="T-Bond Futures", - exchange="CME", - listed_exchange="CME", - type="futures", - minmov=1, - pricescale=128, - minmove2=4, - fractional=True, - ) - dumped = info.model_dump(exclude_none=True) - assert dumped["fractional"] is True - assert dumped["minmove2"] == 4 - assert dumped["pricescale"] == 128 - - def test_symbol_info_subsessions(self): - info = TVChartSymbolInfo( - name="ES", - description="E-Mini S&P", - exchange="CME", - listed_exchange="CME", - session="0930-1600", - subsession_id="regular", - subsessions=[ - TVChartLibrarySubsessionInfo( - id="regular", description="Regular", session="0930-1600" - ), - TVChartLibrarySubsessionInfo( - id="extended", description="Extended", session="0400-2000" - ), - ], - ) - dumped = info.model_dump(exclude_none=True) - assert len(dumped["subsessions"]) == 2 - assert dumped["subsessions"][0]["id"] == "regular" - - def test_symbol_info_price_sources(self): - info = TVChartSymbolInfo( - name="AAPL", - description="Apple", - exchange="NASDAQ", - listed_exchange="NASDAQ", - price_sources=[ - TVChartSymbolInfoPriceSource(id="1", name="Spot Price"), - TVChartSymbolInfoPriceSource(id="2", name="Bid"), - ], - price_source_id="1", - ) - dumped = info.model_dump(exclude_none=True) - assert len(dumped["price_sources"]) == 2 - assert dumped["price_source_id"] == "1" - - def test_symbol_info_alias(self): - info = TVChartSymbolInfo( - name="X", description="", exchange="", listed_exchange="", symbol_type="futures" - ) - dumped = info.model_dump(exclude_none=True) - assert dumped["symbol_type"] == "futures" - - # --- TVChartDatafeedConfiguration --- - - def test_datafeed_configuration_model(self): - cfg = TVChartDatafeedConfiguration( - exchanges=[TVChartExchange(value="NYSE", name="New York Stock Exchange", desc="")], - symbols_types=[TVChartDatafeedSymbolType(name="Stock", value="stock")], - supported_resolutions=["1", "5", "15", "60", "1D", "1W", "1M"], - supports_marks=True, - supports_timescale_marks=True, - supports_time=True, - currency_codes=["USD", "EUR"], - ) - dumped = cfg.model_dump(exclude_none=True) - assert len(dumped["exchanges"]) == 1 - assert dumped["exchanges"][0]["value"] == "NYSE" - assert dumped["supports_marks"] is True - assert dumped["supports_time"] is True - assert "1D" in dumped["supported_resolutions"] - - def test_datafeed_configuration_empty(self): - cfg = TVChartDatafeedConfiguration() - dumped = cfg.model_dump(exclude_none=True) - assert dumped == {} - - # --- TVChartSearchSymbolResultItem --- - - def test_search_symbol_result_item(self): - item = TVChartSearchSymbolResultItem( - symbol="AAPL", - description="Apple Inc.", - exchange="NasdaqNM", - type="stock", - ticker="AAPL", - ) - dumped = item.model_dump(exclude_none=True) - assert dumped["symbol"] == "AAPL" - assert dumped["type"] == "stock" - - # --- TVChartBar --- - - def test_bar_model(self): - bar = TVChartBar( - time=1700000000000, open=100.0, high=105.0, low=99.0, close=103.0, volume=1000000.0 - ) - dumped = bar.model_dump() - assert dumped["time"] == 1700000000000 - assert dumped["open"] == 100.0 - assert dumped["volume"] == 1000000.0 - - def test_bar_model_no_volume(self): - bar = TVChartBar(time=1700000000000, open=100.0, high=105.0, low=99.0, close=103.0) - dumped = bar.model_dump(exclude_none=True) - assert "volume" not in dumped - - # --- TVChartMark --- - - def test_mark_model(self): - mark = TVChartMark( - id="m1", - time=1700000000, - color="red", - text="Earnings", - label="E", - label_font_color="white", - min_size=24, - ) - dumped = mark.model_dump(exclude_none=True) - assert dumped["id"] == "m1" - assert dumped["color"] == "red" - assert dumped["label"] == "E" - - # --- TVChartTimescaleMark --- - - def test_timescale_mark_model(self): - mark = TVChartTimescaleMark( - id="ts1", - time=1700000000, - color="blue", - label="D", - tooltip=["Dividend", "$0.25/share"], - ) - dumped = mark.model_dump(exclude_none=True) - assert dumped["id"] == "ts1" - assert dumped["tooltip"] == ["Dividend", "$0.25/share"] - - # --- Config request/response --- - - def test_config_request_response(self): - req = TVChartDatafeedConfigRequest(request_id="cfg-1", chart_id="main") - assert req.request_id == "cfg-1" - - resp = TVChartDatafeedConfigResponse( - request_id="cfg-1", - config=TVChartDatafeedConfiguration(supports_marks=True), - ) - dumped = resp.model_dump(exclude_none=True) - assert dumped["config"]["supports_marks"] is True - - # --- Search request/response --- - - def test_search_request_and_response_models(self): - req = TVChartDatafeedSearchRequest( - request_id="req-1", - query="aapl", - chart_id="main", - limit=15, - exchange="NASDAQ", - symbol_type="stock", - ) - req_dump = req.model_dump() - assert req_dump["request_id"] == "req-1" - assert req_dump["query"] == "aapl" - assert req_dump["exchange"] == "NASDAQ" - assert req_dump["limit"] == 15 - - resp = TVChartDatafeedSearchResponse( - request_id="req-1", - query="aapl", - items=[ - TVChartSearchSymbolResultItem( - symbol="AAPL", description="Apple Inc.", exchange="NASDAQ" - ) - ], - ) - resp_dump = resp.model_dump(exclude_none=True) - assert resp_dump["request_id"] == "req-1" - assert resp_dump["items"][0]["symbol"] == "AAPL" - - # --- Resolve request/response --- - - def test_resolve_request_and_response_models(self): - req = TVChartDatafeedResolveRequest( - request_id="req-2", symbol="NASDAQ:AAPL", chart_id="main" - ) - req_dump = req.model_dump() - assert req_dump["request_id"] == "req-2" - assert req_dump["symbol"] == "NASDAQ:AAPL" - - resp = TVChartDatafeedResolveResponse( - request_id="req-2", - symbol_info=TVChartSymbolInfo( - name="AAPL", - ticker="AAPL", - description="Apple", - exchange="NASDAQ", - listed_exchange="NASDAQ", - pricescale=100, - ), - ) - resp_dump = resp.model_dump(exclude_none=True) - assert resp_dump["request_id"] == "req-2" - assert resp_dump["symbol_info"]["ticker"] == "AAPL" - assert resp_dump["symbol_info"]["pricescale"] == 100 - - # --- History request/response --- - - def test_history_request_and_response_models(self): - req = TVChartDatafeedHistoryRequest( - request_id="req-3", - symbol="NASDAQ:AAPL", - resolution="1D", - from_time=1_700_000_000, - to_time=1_700_086_400, - count_back=300, - first_data_request=True, - ) - req_dump = req.model_dump() - assert req_dump["resolution"] == "1D" - assert req_dump["count_back"] == 300 - assert req_dump["first_data_request"] is True - - resp = TVChartDatafeedHistoryResponse( - request_id="req-3", - status="ok", - bars=[{"time": 1_700_000_000, "open": 1.0, "high": 2.0, "low": 0.5, "close": 1.5}], - ) - resp_dump = resp.model_dump() - assert resp_dump["request_id"] == "req-3" - assert resp_dump["status"] == "ok" - assert len(resp_dump["bars"]) == 1 - - def test_history_response_no_data(self): - resp = TVChartDatafeedHistoryResponse( - request_id="req-4", - status="no_data", - bars=[], - no_data=True, - next_time=1_699_900_000_000, - ) - dumped = resp.model_dump(exclude_none=True) - assert dumped["no_data"] is True - assert dumped["next_time"] == 1_699_900_000_000 - assert dumped["status"] == "no_data" - - # --- Subscribe/Unsubscribe --- - - def test_subscribe_request(self): - req = TVChartDatafeedSubscribeRequest( - request_id="sub-1", - symbol="AAPL", - resolution="1", - listener_guid="guid-abc-123", - chart_id="main", - ) - assert req.listener_guid == "guid-abc-123" - assert req.resolution == "1" - - def test_unsubscribe_request(self): - req = TVChartDatafeedUnsubscribeRequest( - listener_guid="guid-abc-123", - chart_id="main", - ) - assert req.listener_guid == "guid-abc-123" - - def test_bar_update(self): - update = TVChartDatafeedBarUpdate( - listener_guid="guid-abc-123", - bar={ - "time": 1700000000000, - "open": 100, - "high": 105, - "low": 99, - "close": 103, - "volume": 50000, - }, - ) - dumped = update.model_dump() - assert dumped["listener_guid"] == "guid-abc-123" - assert dumped["bar"]["close"] == 103 - - # --- Marks request/response --- - - def test_marks_request_response(self): - req = TVChartDatafeedMarksRequest( - request_id="m-1", - symbol="AAPL", - from_time=1700000000, - to_time=1700086400, - resolution="1D", - ) - assert req.from_time == 1700000000 - - resp = TVChartDatafeedMarksResponse( - request_id="m-1", - marks=[ - { - "id": "mk1", - "time": 1700000000, - "color": "red", - "text": "Buy", - "label": "B", - "labelFontColor": "#fff", - "minSize": 20, - } - ], - ) - dumped = resp.model_dump(exclude_none=True) - assert len(dumped["marks"]) == 1 - - # --- TimescaleMarks request/response --- - - def test_timescale_marks_request_response(self): - req = TVChartDatafeedTimescaleMarksRequest( - request_id="ts-1", - symbol="AAPL", - from_time=1700000000, - to_time=1700086400, - resolution="1D", - ) - assert req.resolution == "1D" - - resp = TVChartDatafeedTimescaleMarksResponse( - request_id="ts-1", - marks=[ - { - "id": "tsm1", - "time": 1700000000, - "color": "blue", - "label": "D", - "tooltip": ["Dividend"], - } - ], - ) - dumped = resp.model_dump(exclude_none=True) - assert len(dumped["marks"]) == 1 - - # --- ServerTime request/response --- - - def test_server_time_request_response(self): - req = TVChartDatafeedServerTimeRequest(request_id="st-1", chart_id="main") - assert req.request_id == "st-1" - - resp = TVChartDatafeedServerTimeResponse( - request_id="st-1", - time=1700000000, - ) - dumped = resp.model_dump(exclude_none=True) - assert dumped["time"] == 1700000000 - - -# ============================================================================= -# _serialize_series_from_rows tests -# ============================================================================= - - -class TestSerializeSeriesFromRows: - """Test bulk row serialization into TVChartSeriesData.""" - - def test_basic_ohlcv(self): - rows = _make_ohlcv_rows(3) - ohlcv_map = _resolve_ohlcv_columns(list(rows[0].keys())) - result = _serialize_series_from_rows(rows, ohlcv_map, "test") - assert result.series_id == "test" - assert len(result.bars) == 3 - assert result.series_type == SeriesType.CANDLESTICK - assert result.has_volume is True - assert result.total_rows == 3 - assert result.truncated_rows == 0 - - def test_truncation(self): - rows = _make_ohlcv_rows(10) - ohlcv_map = _resolve_ohlcv_columns(list(rows[0].keys())) - result = _serialize_series_from_rows(rows, ohlcv_map, "test", max_bars=3) - assert len(result.bars) == 3 - assert result.total_rows == 10 - assert result.truncated_rows == 7 - - -# ============================================================================= -# build_tvchart_toolbars tests -# ============================================================================= - - -class TestBuildTVChartToolbars: - """Test the toolbar factory function.""" - - def test_returns_four_toolbars(self): - toolbars = build_tvchart_toolbars() - assert len(toolbars) == 4 - - def test_toolbar_positions(self): - toolbars = build_tvchart_toolbars() - positions = [tb.position for tb in toolbars] - assert "top" in positions - assert "left" in positions - assert "bottom" in positions - assert "inside" in positions - - def test_header_has_chart_type_select(self): - toolbars = build_tvchart_toolbars() - header = next(tb for tb in toolbars if tb.position == "top") - ids = [item.component_id for item in header.items] - assert "wrap-tvchart-chart-type" in ids - - def test_left_has_drawing_tools(self): - toolbars = build_tvchart_toolbars() - left = next(tb for tb in toolbars if tb.position == "left") - ids = [item.component_id for item in left.items] - assert "wrap-tvchart-tool-crosshair" in ids - assert "wrap-tvchart-group-lines" in ids - assert "wrap-tvchart-group-channels" in ids - assert "wrap-tvchart-tool-eraser" in ids - - def test_bottom_has_time_range_tabs(self): - toolbars = build_tvchart_toolbars() - bottom = next(tb for tb in toolbars if tb.position == "bottom") - ids = [item.component_id for item in bottom.items] - assert "tvchart-time-range" in ids - assert "wrap-tvchart-date-range" in ids - - def test_bottom_uses_daily_practical_ranges_when_intraday_is_unavailable(self): - toolbars = build_tvchart_toolbars(intervals=["1d", "1w", "1M"], selected_interval="1d") - bottom = next(tb for tb in toolbars if tb.position == "bottom") - time_range = next( - item for item in bottom.items if item.component_id == "tvchart-time-range" - ) - - assert [option.value for option in time_range.options] == [ - "all", - "10y", - "5y", - "1y", - "ytd", - "6m", - "3m", - "1m", - ] - assert time_range.selected == "1y" - assert [option.data_attrs["target-interval"] for option in time_range.options] == [ - "1d", - "1M", - "1w", - "1d", - "1d", - "1d", - "1d", - "1d", - ] - - def test_bottom_uses_longer_ranges_for_weekly_only_data(self): - toolbars = build_tvchart_toolbars(intervals=["1w", "1M"], selected_interval="1w") - bottom = next(tb for tb in toolbars if tb.position == "bottom") - time_range = next( - item for item in bottom.items if item.component_id == "tvchart-time-range" - ) - - assert [option.value for option in time_range.options] == [ - "all", - "10y", - "5y", - "3y", - "1y", - "ytd", - "6m", - "3m", - ] - assert time_range.selected == "1y" - assert [option.data_attrs["target-interval"] for option in time_range.options] == [ - "1w", - "1M", - "1w", - "1w", - "1w", - "1w", - "1w", - "1w", - ] - - def test_bottom_uses_multi_year_ranges_for_quarterly_data(self): - toolbars = build_tvchart_toolbars(intervals=["3M", "12M"], selected_interval="3M") - bottom = next(tb for tb in toolbars if tb.position == "bottom") - time_range = next( - item for item in bottom.items if item.component_id == "tvchart-time-range" - ) - - assert [option.value for option in time_range.options] == [ - "all", - "20y", - "10y", - "5y", - "3y", - "ytd", - ] - assert time_range.selected == "ytd" - assert [option.data_attrs["target-interval"] for option in time_range.options] == [ - "3M", - "3M", - "3M", - "3M", - "3M", - "3M", - ] - - def test_bottom_intraday_ranges_expose_expected_target_intervals_and_tooltips(self): - toolbars = build_tvchart_toolbars( - intervals=["1m", "3m", "5m", "15m", "30m", "45m", "1h", "2h", "4h", "1d", "1w", "1M"], - selected_interval="1d", - ) - bottom = next(tb for tb in toolbars if tb.position == "bottom") - time_range = next( - item for item in bottom.items if item.component_id == "tvchart-time-range" - ) - options = {option.value: option for option in time_range.options} - - assert options["1d"].data_attrs["target-interval"] == "1m" - assert options["5d"].data_attrs["target-interval"] == "5m" - assert options["1m"].data_attrs["target-interval"] == "30m" - assert options["3m"].data_attrs["target-interval"] == "1h" - assert options["6m"].data_attrs["target-interval"] == "2h" - assert options["ytd"].data_attrs["target-interval"] == "1d" - assert options["all"].data_attrs["target-interval"] == "1d" - assert options["1d"].description == "1 day" - assert options["5d"].description == "5 days" - assert options["1m"].description == "1 month" - assert options["3m"].description == "3 months" - assert options["6m"].description == "6 months" - assert options["ytd"].description == "Year to date" - assert options["all"].description == "All" - assert options["10y"].label == "10y" - assert options["ytd"].label == "YTD" - assert options["all"].label == "Max" - - def test_header_has_save_button(self): - toolbars = build_tvchart_toolbars() - header = next(tb for tb in toolbars if tb.position == "top") - ids = [item.component_id for item in header.items] - assert "wrap-tvchart-save-split" in ids - - def test_inside_toolbar_legend_div_has_no_inline_script(self): - """Legend script is loaded via 11-legend.js, not inline on the Div.""" - toolbars = build_tvchart_toolbars() - inside = next(tb for tb in toolbars if tb.position == "inside") - legend = inside.items[0] - assert legend.script is None or legend.script == "" - - -class TestTVChartFrontendStateContracts: - """Validate structural and behavioural contracts in the JS frontend source. - - Each test scopes its assertions to a specific function or handler body, - verifying that the tested property lives in the correct execution context - rather than just checking for string presence anywhere in ~12 000 lines. - """ - - @pytest.fixture - def tvchart_defaults_js(self) -> str: - from pywry.assets import get_tvchart_defaults_js - - return get_tvchart_defaults_js() - - # ------------------------------------------------------------------ - # Helpers: scope extraction - # ------------------------------------------------------------------ - - @staticmethod - def _skip_comment(src: str, i: int, n: int) -> int | None: - """If *i* points to the start of a JS comment, return the index past it.""" - if i + 1 >= n: - return None - nxt = src[i + 1] - if nxt == "/": - nl = src.find("\n", i) - return (nl + 1) if nl != -1 else n - if nxt == "*": - end = src.find("*/", i + 2) - return (end + 2) if end != -1 else n - return None - - @staticmethod - def _extract_braced(src: str, search_from: int) -> str: - """Return text from *search_from* through the matching closing brace. - - Handles string literals (single, double, backtick) and comments - (// and /* */) so braces inside them are not counted. - """ - i = src.index("{", search_from) - depth = 0 - in_string: str | None = None - escaped = False - n = len(src) - while i < n: - ch = src[i] - if escaped: - escaped = False - i += 1 - continue - if ch == "\\": - escaped = True - i += 1 - continue - if in_string: - if ch == in_string: - in_string = None - i += 1 - continue - if ch in ("'", '"', "`"): - in_string = ch - elif ch == "/": - skip = TestTVChartFrontendStateContracts._skip_comment(src, i, n) - if skip is not None: - i = skip - continue - elif ch == "{": - depth += 1 - elif ch == "}": - depth -= 1 - if depth == 0: - return src[search_from : i + 1] - i += 1 - return src[search_from:] - - def _fn(self, src: str, name: str) -> str: - """Extract the full body of ``function (...)``.""" - return self._extract_braced(src, src.index(f"function {name}(")) - - def _handler(self, src: str, event: str) -> str: - """Extract the body of an event listener for ````. - - Accepts both ``window.pywry.on('', ...)`` and - ``bridge.on('', ...)`` — the tvchart event handlers are - registered against a local ``bridge`` reference that defaults to - ``window.pywry``. - """ - candidates = ( - f"window.pywry.on('{event}'", - f"bridge.on('{event}'", - ) - for candidate in candidates: - idx = src.find(candidate) - if idx != -1: - return self._extract_braced(src, idx) - raise ValueError(f"No handler registration found for event '{event}'") - - def _create_body(self, src: str) -> str: - """Extract the PYWRY_TVCHART_CREATE function body.""" - start = src.index("window.PYWRY_TVCHART_CREATE = function") - end = src.index("window.PYWRY_TVCHART_UPDATE", start) - return src[start:end] - - # ------------------------------------------------------------------ - # State export & request - # ------------------------------------------------------------------ - - def test_state_export_returns_all_survival_fields(self, tvchart_defaults_js: str): - """_tvExportState must return rawData, drawings, and indicators so a - chart can be fully reconstructed after a page reload.""" - body = self._fn(tvchart_defaults_js, "_tvExportState") - # The return object must include each survival-critical field - for field in ("rawData", "drawings", "indicators"): - assert f"{field}: {field}" in body, ( - f"_tvExportState return object must include '{field}'" - ) - # Also verify visibleRange is captured (for state, not layout) - assert "getVisibleLogicalRange()" in body - - def test_state_response_echoes_request_context(self, tvchart_defaults_js: str): - """The request-state handler must attach data.context to the response - so the caller can correlate responses to requests.""" - body = self._handler(tvchart_defaults_js, "tvchart:request-state") - # Must call the export function - assert "_tvExportState(" in body - # Must propagate context - assert "data.context" in body - assert "Object.assign" in body - # Must emit the response event - assert "tvchart:state-response" in body - - # ------------------------------------------------------------------ - # Legend scoping & controls - # ------------------------------------------------------------------ - - def test_legend_setup_is_scoped_to_chart_instance(self, tvchart_defaults_js: str): - """_tvSetupLegendControls must accept chartId and use scoped DOM - queries — never reference hardcoded chart IDs or global singletons.""" - body = self._fn(tvchart_defaults_js, "_tvSetupLegendControls") - # Scoped lookup pattern - assert "_tvResolveChartEntry(chartId)" in body - assert "_tvScopedById(chartId" in body - # Uses local scopedById helper for DOM queries - assert "function scopedById(id)" in body - # Never hardcodes the first chart - assert "chartIds[0]" not in body - - def test_legend_has_per_series_action_buttons(self, tvchart_defaults_js: str): - """Each legend series row must have hide, settings, remove, and more actions.""" - body = self._fn(tvchart_defaults_js, "_tvSetupLegendControls") - required_actions = ["hide", "settings", "remove", "more"] - for action in required_actions: - assert f'data-action="{action}"' in body, ( - f"Legend row must have a '{action}' action button" - ) - - def test_legend_listens_for_external_refresh_events(self, tvchart_defaults_js: str): - """Legend must subscribe to pywry:legend-refresh for external updates - (compare add/remove, indicator changes).""" - body = self._fn(tvchart_defaults_js, "_tvSetupLegendControls") - assert "pywry:legend-refresh" in body - - # ------------------------------------------------------------------ - # Volume subplot - # ------------------------------------------------------------------ - - def test_volume_reserve_called_in_both_lifecycle_paths(self, tvchart_defaults_js: str): - """_tvReserveVolumePane must be called in both the static-data (CREATE) - and datafeed-mode code paths to ensure volume always gets a subplot.""" - create = self._create_body(tvchart_defaults_js) - # Static lifecycle must call reserve - assert "_tvReserveVolumePane(entry," in create, ( - "PYWRY_TVCHART_CREATE must call _tvReserveVolumePane for static data" - ) - # The reserve function itself must accept (entry, seriesId) - reserve_fn = self._fn(tvchart_defaults_js, "_tvReserveVolumePane") - assert "entry._volumePaneBySeries" in reserve_fn - # Main volume always gets pane index 1 - assert "paneIndex = 1" in reserve_fn - - def test_volume_pane_height_is_clamped_proportionally(self, tvchart_defaults_js: str): - """_tvApplyDefaultVolumePaneHeight must clamp the height to a reasonable - fraction of the container, not use a fixed pixel value. The formula - prevents the volume pane from being too small or too large.""" - body = self._fn(tvchart_defaults_js, "_tvApplyDefaultVolumePaneHeight") - # Must reference container height for proportional sizing - assert "containerHeight" in body - # Clamp formula: min 64, max 132, 12% of container - assert "Math.max(64" in body - assert "Math.min(132" in body - assert "0.12" in body - # Actually sets the height on the pane - assert "setHeight(desiredHeight)" in body - - def test_volume_options(self, tvchart_defaults_js: str): - """Volume series uses the right-side price scale of its own pane, - keeps the latest-value label visible, and suppresses the price line.""" - body = self._fn(tvchart_defaults_js, "_tvBuildVolumeOptions") - assert "lastValueVisible: true" in body, ( - "Volume needs the latest-value label so the right axis renders ticks" - ) - assert "priceLineVisible: false" in body, ( - "Volume must hide priceLineVisible to avoid horizontal-line clutter" - ) - # Volume series binds to the standard 'right' price scale of its - # own pane (visible by default), not a hidden custom 'volume' scale. - assert "priceScaleId: 'right'" in body - - def test_volume_auto_enables_in_create(self, tvchart_defaults_js: str): - """PYWRY_TVCHART_CREATE enables volume by default when enableVolume is - not explicitly false, and applies the default pane height.""" - create = self._create_body(tvchart_defaults_js) - assert "enableVolume !== false" in create - assert "_tvApplyDefaultVolumePaneHeight(" in create - - # ------------------------------------------------------------------ - # Time range (zoom-only, no interval switching) - # ------------------------------------------------------------------ - - def test_time_range_handler_is_zoom_only(self, tvchart_defaults_js: str): - """The tvchart:time-range handler must apply zoom via - _tvApplyTimeRangeSelection and must NOT switch the interval. - Interval authority belongs exclusively to the interval dropdown.""" - body = self._handler(tvchart_defaults_js, "tvchart:time-range") - assert "_tvApplyTimeRangeSelection(" in body, ( - "time-range handler must call _tvApplyTimeRangeSelection" - ) - # Must not contain interval-switching patterns - assert "_pendingTimeRange" not in body, ( - "time-range handler must not defer to interval change" - ) - assert "targetInterval" not in body, "time-range handler must not switch the interval" - assert "tvchart:interval-change" not in body, ( - "time-range handler must not emit interval-change events" - ) - - def test_time_range_selection_handles_all_and_ytd(self, tvchart_defaults_js: str): - """_tvApplyTimeRangeSelection must have explicit branches for 'all' - (fit all data) and 'ytd' (year-to-date), plus use _tvResolveRangeSpanDays - for named presets like '1y', '3m', etc. For absolute date-range - requests it must delegate to _tvApplyAbsoluteDateRange.""" - body = self._fn(tvchart_defaults_js, "_tvApplyTimeRangeSelection") - assert "range === 'all'" in body - assert "fitContent()" in body - assert "range === 'ytd'" in body - assert "_tvResolveRangeSpanDays(" in body - # Absolute date-range requests are handled by a separate helper. - assert "function _tvApplyAbsoluteDateRange" in tvchart_defaults_js - - def test_range_span_resolver_covers_standard_presets(self, tvchart_defaults_js: str): - """_tvResolveRangeSpanDays must define time spans for all standard presets.""" - body = self._fn(tvchart_defaults_js, "_tvResolveRangeSpanDays") - for preset in ("'1d'", "'5d'", "'1m'", "'3m'", "'6m'", "'1y'", "'5y'"): - assert preset in body, f"Range resolver must cover preset {preset}" - - # No anti-pattern - def test_no_pending_time_range_state(self, tvchart_defaults_js: str): - """_pendingTimeRange was an old pattern that coupled range to interval. - It must not exist anywhere in the codebase.""" - assert "_pendingTimeRange" not in tvchart_defaults_js - - # ------------------------------------------------------------------ - # Legend hover & crosshair - # ------------------------------------------------------------------ - - def test_legend_hover_falls_back_to_cached_data(self, tvchart_defaults_js: str): - """_legendResolveHoveredPoint must try live seriesData first, then - fall back to cached _seriesRawData for cases where seriesData is - unavailable (e.g. compare series, synthetic indicators).""" - body = self._fn(tvchart_defaults_js, "_legendResolveHoveredPoint") - # Try live data first via param.seriesData.get() - assert "param.seriesData" in body - # Fall back to cached raw data - assert "_seriesRawData" in body - # Must handle null/missing time - assert "param.time" in body - - def test_legend_hover_refresh_functions_exist_and_are_called(self, tvchart_defaults_js: str): - """_tvRefreshLegendTitle, _tvEmitLegendRefresh, and _tvRenderHoverLegend - must exist and be called after compare series changes (add/remove).""" - # Functions must exist - for fn in ("_tvRefreshLegendTitle", "_tvEmitLegendRefresh", "_tvRenderHoverLegend"): - assert f"function {fn}(" in tvchart_defaults_js, f"{fn} must be defined" - - # They must each be called from more than just their definition - # (at least 2 occurrences = definition + call site) - for fn in ("_tvRefreshLegendTitle(", "_tvEmitLegendRefresh(", "_tvRenderHoverLegend("): - count = tvchart_defaults_js.count(fn) - assert count >= 2, ( - f"{fn} found {count} time(s) — must be defined AND called from at least one site" - ) - - def test_crosshair_mode_controlled_by_prefs(self, tvchart_defaults_js: str): - """Crosshair visibility must be driven by prefs.crosshairEnabled so - user can toggle it, and default to disabled.""" - body = self._fn(tvchart_defaults_js, "_tvCrosshairLinesVisible") - assert "crosshairEnabled" in body - # _tvApplyHoverReadoutMode must exist to sync crosshair mode to chart - assert "function _tvApplyHoverReadoutMode(" in tvchart_defaults_js - - # ------------------------------------------------------------------ - # Volume divider clearance & scale placement - # ------------------------------------------------------------------ - - def test_divider_clearance_conditionally_expands_bottom_margin(self, tvchart_defaults_js: str): - """_tvEnforceMainScaleDividerClearance must increase the bottom margin - only when a volume pane exists, so the lowest price label does not - overlay the pane divider.""" - body = self._fn(tvchart_defaults_js, "_tvEnforceMainScaleDividerClearance") - # Check for volume pane existence - assert "volumeMap" in body - # Conditional increase - assert "Math.max(bottom" in body - # Must apply via priceScale().applyOptions - assert "scaleMargins" in body - # Must resolve the scale side dynamically - assert "_tvResolveScalePlacement(entry)" in body - - def test_scale_placement_resolver_is_used_at_series_creation(self, tvchart_defaults_js: str): - """_tvResolveScalePlacement must be called wherever series are created - or scale options are applied, so scale-side is never hardcoded.""" - # Function must exist - assert "function _tvResolveScalePlacement(entry)" in tvchart_defaults_js - # Must be called at multiple sites (not just defined) - call_count = tvchart_defaults_js.count("_tvResolveScalePlacement(entry)") - assert call_count >= 3, ( - f"_tvResolveScalePlacement called {call_count} time(s); expected >= 3 " - "(definition + series creation + divider clearance)" - ) - - # ------------------------------------------------------------------ - # Layout save/open (client-side persistence) - # ------------------------------------------------------------------ - - def test_layout_persist_builds_summary_from_contents(self, tvchart_defaults_js: str): - """_tvLayoutPersist must build a summary from indicator/drawing names - for the index entry, not store symbol/timeframe (layouts are portable).""" - body = self._fn(tvchart_defaults_js, "_tvLayoutPersist") - # Builds summary from indicators - assert "indNames" in body or "summary" in body - # Stores to local storage via adapter - assert "_tvStorageSet(" in body - # Index entry has summary field - assert "summary:" in body - # Must NOT store symbol or timeframe - assert "symbol:" not in body - assert "timeframe:" not in body - - def test_layout_apply_restores_drawings_and_indicators(self, tvchart_defaults_js: str): - """_tvApplyLayout must restore drawings and indicators, handle grouped - indicators (Bollinger Bands deduplication), and NOT restore - visibleRange (layouts are portable across charts).""" - body = self._fn(tvchart_defaults_js, "_tvApplyLayout") - # Restores drawings - assert "_tvRenderDrawings(" in body - # Removes old indicators before adding saved ones - assert "_tvRemoveIndicator(" in body - assert "_tvAddIndicator(" in body - # Grouped indicator deduplication - assert "restoredGroups" in body - # Must NOT restore visibleRange - assert ( - "setVisibleLogicalRange" not in body - and "visibleRange" not in body.split("// visibleRange")[0] - ), "Layout apply must not restore visibleRange (portability contract)" - # Restores settings - assert "_tvApplySettingsToChart(" in body - - def test_layout_meta_label_shows_summary_not_symbol(self, tvchart_defaults_js: str): - """_tvLayoutMetaLabel should show summary + date, not symbol/timeframe.""" - body = self._fn(tvchart_defaults_js, "_tvLayoutMetaLabel") - assert "summary" in body - assert "savedAt" in body or "Date" in body - # Must NOT reference symbol or timeframe - assert "symbol" not in body.lower() - assert "timeframe" not in body.lower() - - def test_no_alert_in_layout_flow(self, tvchart_defaults_js: str): - """Layout save/open must use toast notifications, never window.alert.""" - assert "window.alert(" not in tvchart_defaults_js - - # ------------------------------------------------------------------ - # Candle settings: opacity & colour controls - # ------------------------------------------------------------------ - - def test_candle_colours_use_opacity_popup_not_separate_rows(self, tvchart_defaults_js: str): - """Candle body/border/wick colour controls must use the unified - color-opacity popup. Separate addOpacityRow calls for these must - NOT exist (they create redundant UI rows).""" - # The unified popup must exist - assert "function _tvShowColorOpacityPopup(" in tvchart_defaults_js - # Old per-element opacity rows must NOT be used for candle parts - for part in ("Body", "Borders", "Wick"): - assert f"addOpacityRow(lineSection, '{part}'" not in tvchart_defaults_js, ( - f"Candle {part} must use color-opacity popup, not a separate opacity row" - ) - # Combined opacity keys must exist for all candle parts - for part in ("Body", "Borders", "Wick"): - assert f"'{part}-Opacity'" in tvchart_defaults_js - - def test_candle_colour_with_opacity_applied_for_all_parts(self, tvchart_defaults_js: str): - """All six candle colour keys (Body/Borders/Wick x Up/Down) must be - passed through _tvColorWithOpacity so opacity is actually applied.""" - parts = [ - "Body-Up Color", - "Body-Down Color", - "Borders-Up Color", - "Borders-Down Color", - "Wick-Up Color", - "Wick-Down Color", - ] - for part in parts: - assert f"_tvColorWithOpacity(settings['{part}']" in tvchart_defaults_js, ( - f"Missing _tvColorWithOpacity call for '{part}'" - ) - - def test_settings_collect_hidden_inputs_for_opacity(self, tvchart_defaults_js: str): - """collectSettingsFromPanel must read hidden inputs (used for opacity - sliders) in addition to number/text/range controls.""" - assert ( - "ctrl.type === 'number' || ctrl.type === 'text' || ctrl.type === 'range' || ctrl.type === 'hidden'" - in tvchart_defaults_js - ) - - # ------------------------------------------------------------------ - # Status-line and scales settings - # ------------------------------------------------------------------ - - def test_settings_row_helpers_exist(self, tvchart_defaults_js: str): - """Shared settings row builders must be defined so all settings tabs - have consistent layout and control alignment.""" - helpers = [ - "addIndentedCheckboxRow(parent, label, checked)", - "addCheckboxSliderRow(parent, label, checked, enabledSetting, sliderValue, sliderSetting)", - "addNumberInputRow(parent, label, settingKey, value, min, max, step, unitText, inputClassName)", - "addColorSwatchRow(parent, label, color, settingKey)", - "addCheckboxInputRow(parent, label, checked, enabledSetting, inputValue, inputSetting)", - "addSelectColorRow(parent, label, options, selected, selectSetting, color, colorSetting)", - ] - for sig in helpers: - assert f"function {sig}" in tvchart_defaults_js, ( - f"Settings helper 'function {sig}' must be defined" - ) - - def test_scales_settings_uses_full_value_label(self, tvchart_defaults_js: str): - """The scales tab must use the full 'Value according to scale' label. - A truncated 'Value according to sc...' label broke the settings key - mapping. Fallback for the truncated key must also exist.""" - assert "'Value according to scale'" in tvchart_defaults_js - # The truncated key must NOT be used in addSelectRow calls - assert "addSelectRow(scalesSection, 'Value according to sc...'" not in tvchart_defaults_js - # Fallback for layouts saved with the truncated key - assert "'Value according to sc...'" in tvchart_defaults_js - - # ------------------------------------------------------------------ - # Settings preview & cancel - # ------------------------------------------------------------------ - - def test_settings_preview_pipeline(self, tvchart_defaults_js: str): - """Settings must clone originals for revert, schedule a preview on - input/change events, and revert to originals on cancel.""" - # Original settings cloned for cancel-revert - assert "JSON.parse(JSON.stringify(currentSettings" in tvchart_defaults_js - # Preview functions exist - for fn_name in ("collectSettingsFromPanel", "scheduleSettingsPreview", "persistSettings"): - assert f"function {fn_name}(" in tvchart_defaults_js, ( - f"Settings preview pipeline requires '{fn_name}'" - ) - # Cancel reverts to original - assert "_tvApplySettingsToChart(chartId, entry, originalSettings)" in tvchart_defaults_js - # Preview triggered on user input - assert "addEventListener('input'" in tvchart_defaults_js - assert "addEventListener('change'" in tvchart_defaults_js - - # ------------------------------------------------------------------ - # Chart navigation (scroll/zoom always enabled) - # ------------------------------------------------------------------ - - def test_navigation_disable_restore_symmetry(self, tvchart_defaults_js: str): - """Drawing drag temporarily disables chart navigation. Every disable - call must have a matching restore call — an imbalance would leave - the chart in a broken non-interactive state.""" - disable_count = tvchart_defaults_js.count( - "entry.chart.applyOptions({ handleScroll: false, handleScale: false })" - ) - restore_count = tvchart_defaults_js.count( - "entry.chart.applyOptions({ handleScroll: true, handleScale: true })" - ) - assert disable_count >= 1, "At least one navigation disable expected for drawing drag" - assert disable_count == restore_count, ( - f"Disable ({disable_count}) and restore ({restore_count}) must be symmetric" - ) - - def test_ensure_interactive_navigation_exists_and_is_called(self, tvchart_defaults_js: str): - """_tvEnsureInteractiveNavigation must exist (restores navigation after - overlays close) and be called from at least one site.""" - body = self._fn(tvchart_defaults_js, "_tvEnsureInteractiveNavigation") - # Must re-enable both scroll and scale options - assert "handleScroll" in body or "applyOptions" in body - # Must be called from other code (not just defined) - all_calls = tvchart_defaults_js.count("_tvEnsureInteractiveNavigation(entry)") - assert all_calls >= 2, ( - f"_tvEnsureInteractiveNavigation called {all_calls} time(s); expected >= 2 " - "(definition + at least one call site)" - ) - - def test_interactive_navigation_options_enable_all_inputs(self, tvchart_defaults_js: str): - """_tvInteractiveNavigationOptions must enable mouse wheel, pressed - mouse move (pan), and pinch zoom.""" - body = self._fn(tvchart_defaults_js, "_tvInteractiveNavigationOptions") - for opt in ("mouseWheel: true", "pressedMouseMove: true", "pinch: true"): - assert opt in body, f"Interactive navigation must have {opt}" - - # ------------------------------------------------------------------ - # Chart-type change: ordering & scoping - # ------------------------------------------------------------------ - - def test_chart_type_change_adds_new_series_before_removing_old(self, tvchart_defaults_js: str): - """Chart-type switching must add the replacement series BEFORE removing - the old one. If the old (sole) series in a pane is removed first the - pane is destroyed and renumbered by Lightweight Charts, causing the new - series to land in the wrong pane.""" - handler = self._handler(tvchart_defaults_js, "tvchart:chart-type-change") - add_pos = handler.index("_tvAddSeriesCompat(entry.chart,") - remove_pos = handler.index("entry.chart.removeSeries(oldSeries)") - assert add_pos < remove_pos, ( - "chart-type-change handler must call _tvAddSeriesCompat BEFORE " - "removeSeries to prevent pane collapse" - ) - - def test_settings_rebuild_adds_new_series_before_removing_old(self, tvchart_defaults_js: str): - """Series-settings OK handler must also add-then-remove to keep the - pane alive while replacing the series object.""" - anchor = "_tvAddSeriesCompat(entry.chart, targetType, rebuiltOptions" - settings_start = tvchart_defaults_js.index(anchor) - region = tvchart_defaults_js[max(0, settings_start - 600) : settings_start + 600] - add_pos = region.index(anchor) - remove_pos = region.index("entry.chart.removeSeries(oldSeries)") - assert add_pos < remove_pos, ( - "series-settings rebuild must call _tvAddSeriesCompat BEFORE " - "removeSeries to prevent pane collapse" - ) - - def test_chart_type_change_handler_scoped_to_single_chart(self, tvchart_defaults_js: str): - """Chart-type changes must target a single resolved chart entry, never - iterate over all charts globally.""" - handler = self._handler(tvchart_defaults_js, "tvchart:chart-type-change") - assert "_tvResolveChartEntry(" in handler, ( - "handler must use _tvResolveChartEntry to scope to one chart" - ) - assert "Object.keys(window.__PYWRY_TVCHARTS__)" not in handler, ( - "handler must NOT iterate over all charts" - ) - - # ------------------------------------------------------------------ - # Baseline series & chart creation - # ------------------------------------------------------------------ - - def test_baseline_series_computes_base_value_in_both_paths(self, tvchart_defaults_js: str): - """Baseline series must compute baseValue from data in both the - chart-type-change handler AND the initial CREATE path. Without this, - baseValue defaults to 0 and all data renders above the baseline.""" - assert "function _tvComputeBaselineValue(bars, pct)" in tvchart_defaults_js - - handler = self._handler(tvchart_defaults_js, "tvchart:chart-type-change") - assert "_tvComputeBaselineValue(" in handler, ( - "chart-type-change must compute baseValue for Baseline type" - ) - - create = self._create_body(tvchart_defaults_js) - assert "_tvComputeBaselineValue(" in create, ( - "PYWRY_TVCHART_CREATE must compute baseValue for Baseline type" - ) - - def test_create_branches_on_datafeed_mode(self, tvchart_defaults_js: str): - """PYWRY_TVCHART_CREATE must branch on payload.useDatafeed to select - between static-data and streaming-datafeed initialisation.""" - create = self._create_body(tvchart_defaults_js) - assert "payload.useDatafeed" in create - assert "_tvInitDatafeedMode(" in create - - def test_datafeed_init_orchestrates_full_protocol(self, tvchart_defaults_js: str): - """_tvInitDatafeedMode must create a datafeed and call all required - TradingView Datafeed API methods (onReady, resolveSymbol, getBars, - subscribeBars) to establish the streaming connection.""" - body = self._fn(tvchart_defaults_js, "_tvInitDatafeedMode") - assert "_tvCreateDatafeed(" in body - for method in ("onReady", "resolveSymbol", "getBars", "subscribeBars"): - assert f"datafeed.{method}(" in body, ( - f"_tvInitDatafeedMode must call datafeed.{method}()" - ) - - # ------------------------------------------------------------------ - # Layout export (no raw data, portable) - # ------------------------------------------------------------------ - - def test_layout_export_excludes_raw_data_and_visible_range(self, tvchart_defaults_js: str): - """_tvExportLayout must export indicators, drawings, and settings - but NOT rawData or visibleRange (layouts are portable).""" - body = self._fn(tvchart_defaults_js, "_tvExportLayout") - # Must export these - assert "indicators" in body - assert "drawings" in body - assert "settings" in body or "_tvBuildCurrentSettings" in body - # Must NOT include raw bar data (that's for state, not layout) - assert "rawData:" not in body - # Must NOT include visibleRange - assert "visibleRange:" not in body - - def test_layout_export_preserves_grouped_indicator_metadata(self, tvchart_defaults_js: str): - """_tvExportLayout must preserve group-specific metadata like multiplier, - maType, offset, and source so grouped indicators (e.g. Bollinger Bands) - can be faithfully restored.""" - body = self._fn(tvchart_defaults_js, "_tvExportLayout") - for field in ("multiplier", "maType", "offset", "source"): - assert field in body, f"_tvExportLayout must preserve '{field}' for grouped indicators" - - -# ============================================================================= -# TVChartStateMixin tests -# ============================================================================= - - -class _MockEmitter(TVChartStateMixin): - """Concrete class for testing the mixin.""" - - def __init__(self): - self._emitted: list[tuple[str, Any]] = [] - - def emit(self, event_type: str, data: dict[str, Any] | None = None) -> None: - self._emitted.append((event_type, data)) - - -class TestTVChartStateMixin: - """Test TVChartStateMixin methods emit correct events.""" - - def test_update_series(self): - m = _MockEmitter() - m.update_series([{"time": 1, "open": 1, "high": 2, "low": 0, "close": 1}]) - assert len(m._emitted) == 1 - event, payload = m._emitted[0] - assert event == "tvchart:update" - assert "bars" in payload - assert payload["fitContent"] is True - - def test_update_bar(self): - m = _MockEmitter() - bar = {"time": 1700000000, "open": 100, "high": 105, "low": 98, "close": 103} - m.update_bar(bar) - assert len(m._emitted) == 1 - event, payload = m._emitted[0] - assert event == "tvchart:stream" - assert payload["bar"] is bar - - def test_update_bar_with_volume(self): - m = _MockEmitter() - bar = { - "time": 1700000000, - "open": 100, - "high": 105, - "low": 98, - "close": 103, - "volume": 5000, - } - m.update_bar(bar) - _event, payload = m._emitted[0] - assert "volume" in payload - assert payload["volume"]["value"] == 5000 - - def test_add_indicator(self): - m = _MockEmitter() - indicator_data = [{"time": 1, "value": 50}, {"time": 2, "value": 55}] - m.add_indicator(indicator_data, series_id="sma20", series_type="Line") - event, payload = m._emitted[0] - assert event == "tvchart:add-series" - assert payload["seriesId"] == "sma20" - assert payload["seriesType"] == "Line" - assert payload["bars"] is indicator_data - - def test_remove_indicator(self): - m = _MockEmitter() - m.remove_indicator("sma20") - event, payload = m._emitted[0] - assert event == "tvchart:remove-series" - assert payload["seriesId"] == "sma20" - - # -------- built-in indicator engine (JS-side compute) --------------- - - def test_add_builtin_indicator_minimal(self): - m = _MockEmitter() - m.add_builtin_indicator("RSI") - event, payload = m._emitted[0] - assert event == "tvchart:add-indicator" - assert payload == {"name": "RSI"} - - def test_add_builtin_indicator_with_period_and_color(self): - m = _MockEmitter() - m.add_builtin_indicator("Moving Average", period=50, color="#2196F3", method="SMA") - event, payload = m._emitted[0] - assert event == "tvchart:add-indicator" - assert payload["name"] == "Moving Average" - assert payload["method"] == "SMA" - assert payload["period"] == 50 - assert payload["color"] == "#2196F3" - - def test_add_builtin_indicator_passes_bollinger_options(self): - m = _MockEmitter() - m.add_builtin_indicator( - "Bollinger Bands", - period=20, - multiplier=2.0, - ma_type="SMA", - offset=0, - source="close", - ) - _event, payload = m._emitted[0] - # Note: ma_type → maType in payload (per the wire contract) - assert payload["multiplier"] == 2.0 - assert payload["maType"] == "SMA" - assert payload["offset"] == 0 - assert payload["source"] == "close" - - def test_add_builtin_indicator_omits_unset_options(self): - m = _MockEmitter() - m.add_builtin_indicator("RSI", period=12) - _event, payload = m._emitted[0] - # Only the explicit fields land in the payload - assert set(payload.keys()) == {"name", "period"} - - def test_add_builtin_indicator_chart_id(self): - m = _MockEmitter() - m.add_builtin_indicator("Moving Average", period=10, method="SMA", chart_id="alt") - _event, payload = m._emitted[0] - assert payload["chartId"] == "alt" - - def test_add_builtin_indicator_with_method(self): - m = _MockEmitter() - m.add_builtin_indicator("Moving Average", period=14, method="EMA") - _event, payload = m._emitted[0] - assert payload["method"] == "EMA" - - def test_remove_builtin_indicator(self): - m = _MockEmitter() - m.remove_builtin_indicator("ind_sma_99") - event, payload = m._emitted[0] - assert event == "tvchart:remove-indicator" - assert payload == {"seriesId": "ind_sma_99"} - - def test_remove_builtin_indicator_with_chart_id(self): - m = _MockEmitter() - m.remove_builtin_indicator("ind_sma_99", chart_id="alt") - _event, payload = m._emitted[0] - assert payload["chartId"] == "alt" - - def test_list_indicators_default(self): - m = _MockEmitter() - m.list_indicators() - event, payload = m._emitted[0] - assert event == "tvchart:list-indicators" - assert payload == {} - - def test_list_indicators_with_context(self): - m = _MockEmitter() - m.list_indicators(chart_id="alt", context={"trigger": "init"}) - _event, payload = m._emitted[0] - assert payload["chartId"] == "alt" - assert payload["context"] == {"trigger": "init"} - - def test_add_marker(self): - m = _MockEmitter() - markers = [ - { - "time": 1, - "position": "aboveBar", - "shape": "arrowDown", - "color": "red", - "text": "Sell", - } - ] - m.add_marker(markers) - event, payload = m._emitted[0] - assert event == "tvchart:add-markers" - assert payload["markers"] is markers - - def test_add_price_line(self): - m = _MockEmitter() - m.add_price_line(150.0, color="#ff0000", title="Resistance") - event, payload = m._emitted[0] - assert event == "tvchart:add-price-line" - assert payload["price"] == 150.0 - assert payload["color"] == "#ff0000" - assert payload["title"] == "Resistance" - - def test_set_visible_range(self): - m = _MockEmitter() - m.set_visible_range(1700000000, 1700500000) - event, payload = m._emitted[0] - assert event == "tvchart:time-scale" - assert payload["visibleRange"]["from"] == 1700000000 - assert payload["visibleRange"]["to"] == 1700500000 - - def test_fit_content(self): - m = _MockEmitter() - m.fit_content() - event, payload = m._emitted[0] - assert event == "tvchart:time-scale" - assert payload["fitContent"] is True - - def test_apply_chart_options(self): - m = _MockEmitter() - m.apply_chart_options(chart_options={"layout": {"background": {"color": "#000"}}}) - event, payload = m._emitted[0] - assert event == "tvchart:apply-options" - assert "chartOptions" in payload - - def test_request_tvchart_state(self): - m = _MockEmitter() - m.request_tvchart_state(chart_id="chart1") - event, payload = m._emitted[0] - assert event == "tvchart:request-state" - assert payload["chartId"] == "chart1" - - def test_request_tvchart_state_with_context(self): - m = _MockEmitter() - m.request_tvchart_state( - chart_id="chart1", context={"target_view": "watchlist", "reason": "reload"} - ) - event, payload = m._emitted[0] - assert event == "tvchart:request-state" - assert payload["chartId"] == "chart1" - assert payload["context"] == {"target_view": "watchlist", "reason": "reload"} - - def test_chart_id_propagation(self): - m = _MockEmitter() - m.update_series([], chart_id="chart42") - _, payload = m._emitted[0] - assert payload["chartId"] == "chart42" - - def test_series_id_propagation(self): - m = _MockEmitter() - m.update_series([], series_id="overlay") - _, payload = m._emitted[0] - assert payload["seriesId"] == "overlay" - - def test_request_tvchart_symbol_search(self): - m = _MockEmitter() - m.request_tvchart_symbol_search( - "aapl", - request_id="req-s-1", - chart_id="chart1", - limit=7, - exchange="NASDAQ", - symbol_type="stock", - ) - event, payload = m._emitted[0] - assert event == "tvchart:datafeed-search-request" - assert payload["query"] == "aapl" - assert payload["requestId"] == "req-s-1" - assert payload["chartId"] == "chart1" - assert payload["limit"] == 7 - assert payload["exchange"] == "NASDAQ" - assert payload["symbolType"] == "stock" - - def test_respond_tvchart_symbol_search(self): - m = _MockEmitter() - m.respond_tvchart_symbol_search( - request_id="req-s-2", - items=[{"symbol": "NASDAQ:AAPL", "fullName": "Apple Inc."}], - chart_id="chart1", - query="app", - ) - event, payload = m._emitted[0] - assert event == "tvchart:datafeed-search-response" - assert payload["requestId"] == "req-s-2" - assert payload["items"][0]["symbol"] == "NASDAQ:AAPL" - assert payload["chartId"] == "chart1" - assert payload["query"] == "app" - - def test_request_and_respond_tvchart_symbol_resolve(self): - m = _MockEmitter() - m.request_tvchart_symbol_resolve("NASDAQ:AAPL", request_id="req-r-1", chart_id="chart1") - event, payload = m._emitted[0] - assert event == "tvchart:datafeed-resolve-request" - assert payload["symbol"] == "NASDAQ:AAPL" - assert payload["requestId"] == "req-r-1" - - m.respond_tvchart_symbol_resolve( - request_id="req-r-1", - symbol_info={"symbol": "NASDAQ:AAPL", "fullName": "Apple Inc."}, - chart_id="chart1", - ) - event, payload = m._emitted[1] - assert event == "tvchart:datafeed-resolve-response" - assert payload["requestId"] == "req-r-1" - assert payload["symbolInfo"]["symbol"] == "NASDAQ:AAPL" - - def test_request_and_respond_tvchart_history(self): - m = _MockEmitter() - m.request_tvchart_history( - symbol="NASDAQ:AAPL", - resolution="1D", - from_time=1_700_000_000, - to_time=1_700_086_400, - request_id="req-h-1", - chart_id="chart1", - count_back=300, - first_data_request=True, - ) - event, payload = m._emitted[0] - assert event == "tvchart:datafeed-history-request" - assert payload["symbol"] == "NASDAQ:AAPL" - assert payload["resolution"] == "1D" - assert payload["from"] == 1_700_000_000 - assert payload["to"] == 1_700_086_400 - assert payload["countBack"] == 300 - assert payload["firstDataRequest"] is True - - m.respond_tvchart_history( - request_id="req-h-1", - bars=[{"time": 1_700_000_000, "value": 123.4}], - status="ok", - chart_id="chart1", - no_data=False, - next_time=1_699_900_000, - ) - event, payload = m._emitted[1] - assert event == "tvchart:datafeed-history-response" - assert payload["requestId"] == "req-h-1" - assert payload["status"] == "ok" - assert payload["bars"][0]["value"] == 123.4 - assert payload["noData"] is False - assert payload["nextTime"] == 1_699_900_000 - - def test_respond_tvchart_datafeed_config(self): - m = _MockEmitter() - m.respond_tvchart_datafeed_config( - request_id="cfg-1", - config={"supported_resolutions": ["1", "5", "1D"], "supports_marks": True}, - chart_id="chart1", - ) - event, payload = m._emitted[0] - assert event == "tvchart:datafeed-config-response" - assert payload["requestId"] == "cfg-1" - assert payload["config"]["supports_marks"] is True - - def test_respond_tvchart_bar_update(self): - m = _MockEmitter() - m.respond_tvchart_bar_update( - listener_guid="guid-abc", - bar={ - "time": 1700000000000, - "open": 100, - "high": 105, - "low": 99, - "close": 103, - "volume": 50000, - }, - chart_id="chart1", - ) - event, payload = m._emitted[0] - assert event == "tvchart:datafeed-bar-update" - assert payload["listenerGuid"] == "guid-abc" - assert payload["bar"]["close"] == 103 - - def test_respond_tvchart_reset_cache(self): - m = _MockEmitter() - m.respond_tvchart_reset_cache(listener_guid="guid-abc", chart_id="chart1") - event, payload = m._emitted[0] - assert event == "tvchart:datafeed-reset-cache" - assert payload["listenerGuid"] == "guid-abc" - - def test_respond_tvchart_marks(self): - m = _MockEmitter() - m.respond_tvchart_marks( - request_id="m-1", - marks=[{"id": "mk1", "time": 1700000000, "color": "red", "text": "Buy"}], - chart_id="chart1", - ) - event, payload = m._emitted[0] - assert event == "tvchart:datafeed-marks-response" - assert payload["requestId"] == "m-1" - - -class _WirableEmitter(_MockEmitter): - """Mock emitter with on() support for wiring tests.""" - - def __init__(self): - super().__init__() - self._handlers: dict[str, list] = {} - - def on(self, event: str, callback, label=None): - self._handlers.setdefault(event, []).append(callback) - - def fire(self, event: str, data: dict): - for cb in self._handlers.get(event, []): - cb(data, event, "test-label") - - -class TestDatafeedDataRequestWiring: - """Verify that _wire_datafeed_provider registers a data-request handler.""" - - def test_data_request_handler_registered(self): - from unittest.mock import AsyncMock - - from pywry.tvchart.datafeed import DatafeedProvider - - provider = AsyncMock(spec=DatafeedProvider) - provider.supports_search = False - provider.get_config = AsyncMock(return_value={}) - - m = _WirableEmitter() - m._wire_datafeed_provider(provider) - - assert "tvchart:data-request" in m._handlers - - def test_data_request_echoes_interval(self): - from unittest.mock import AsyncMock - - from pywry.tvchart.datafeed import DatafeedProvider - - provider = AsyncMock(spec=DatafeedProvider) - provider.supports_search = False - provider.get_config = AsyncMock(return_value={}) - provider.get_bars = AsyncMock( - return_value={ - "bars": [{"time": 1700000000, "open": 100, "high": 105, "low": 98, "close": 103}], - "status": "ok", - } - ) - - m = _WirableEmitter() - m._wire_datafeed_provider(provider) - - m.fire( - "tvchart:data-request", - { - "chartId": "main", - "interval": "6M", - "symbol": "AAPL", - "periodParams": {"from": 0, "to": 1700000000, "countBack": 300}, - }, - ) - - assert len(m._emitted) == 1 - event, payload = m._emitted[0] - assert event == "tvchart:data-response" - assert payload["interval"] == "6M" - assert len(payload["bars"]) == 1 - assert payload["chartId"] == "main" - # Verify provider.get_bars was called with correct args - provider.get_bars.assert_called_once_with("AAPL", "6M", 0, 1700000000, 300) - - -class TestTVChartStateMixinResponders: - """Remaining responder tests (continued from TestTVChartStateMixin).""" - - def test_respond_tvchart_timescale_marks(self): - m = _MockEmitter() - m.respond_tvchart_timescale_marks( - request_id="ts-1", - marks=[ - {"id": "ts1", "time": 1700000000, "color": "blue", "label": "D", "tooltip": ["Div"]} - ], - chart_id="chart1", - ) - event, payload = m._emitted[0] - assert event == "tvchart:datafeed-timescale-marks-response" - assert payload["requestId"] == "ts-1" - assert len(payload["marks"]) == 1 - - def test_respond_tvchart_server_time(self): - m = _MockEmitter() - m.respond_tvchart_server_time( - request_id="st-1", - time=1700000000, - chart_id="chart1", - ) - event, payload = m._emitted[0] - assert event == "tvchart:datafeed-server-time-response" - assert payload["requestId"] == "st-1" - assert payload["time"] == 1700000000 - - def test_normalize_tvchart_data_list(self): - bars = [{"time": 1, "open": 1, "high": 2, "low": 0, "close": 1}] - result_bars, result_vol = TVChartStateMixin._normalize_tvchart_data(bars) - assert result_bars is bars - assert result_vol == [] - - -# ============================================================================= -# Widget class shape tests -# ============================================================================= - - -class TestPyWryTVChartWidgetShape: - """Verify the TVChart widget class exists with expected attributes.""" - - def test_class_exists(self): - from pywry.widget import PyWryTVChartWidget - - assert PyWryTVChartWidget is not None - - def test_fallback_instantiation(self): - """The widget can be instantiated (at minimum as fallback).""" - from pywry.widget import PyWryTVChartWidget - - w = PyWryTVChartWidget(content="
test
") - assert hasattr(w, "content") - - def test_has_emit_method(self): - from pywry.widget import PyWryTVChartWidget - - w = PyWryTVChartWidget(content="") - assert callable(getattr(w, "emit", None)) - - def test_has_on_method(self): - from pywry.widget import PyWryTVChartWidget - - w = PyWryTVChartWidget(content="") - assert callable(getattr(w, "on", None)) - - -# ============================================================================= -# Public API import tests -# ============================================================================= - - -class TestPublicAPIImports: - """Verify all new symbols are exported from pywry.__init__.""" - - def test_import_tvchart_config(self): - from pywry import TVChartConfig as PublicTVChartConfig - - assert PublicTVChartConfig is not None - - def test_import_tvchart_data(self): - from pywry import TVChartData as PublicTVChartData - - assert PublicTVChartData is not None - - def test_import_tvchart_datafeed_models(self): - from pywry import ( - TVChartBar, - TVChartDatafeedBarUpdate, - TVChartDatafeedConfigRequest, - TVChartDatafeedConfigResponse, - TVChartDatafeedConfiguration, - TVChartDatafeedHistoryRequest, - TVChartDatafeedHistoryResponse, - TVChartDatafeedMarksRequest, - TVChartDatafeedMarksResponse, - TVChartDatafeedResolveRequest, - TVChartDatafeedResolveResponse, - TVChartDatafeedSearchRequest, - TVChartDatafeedSearchResponse, - TVChartDatafeedServerTimeRequest, - TVChartDatafeedServerTimeResponse, - TVChartDatafeedSubscribeRequest, - TVChartDatafeedSymbolType, - TVChartDatafeedTimescaleMarksRequest, - TVChartDatafeedTimescaleMarksResponse, - TVChartDatafeedUnsubscribeRequest, - TVChartExchange, - TVChartLibrarySubsessionInfo, - TVChartMark, - TVChartSearchSymbolResultItem, - TVChartSymbolInfo, - TVChartSymbolInfoPriceSource, - TVChartTimescaleMark, - ) - - assert TVChartBar is not None - assert TVChartDatafeedBarUpdate is not None - assert TVChartDatafeedConfigRequest is not None - assert TVChartDatafeedConfigResponse is not None - assert TVChartDatafeedConfiguration is not None - assert TVChartDatafeedHistoryRequest is not None - assert TVChartDatafeedHistoryResponse is not None - assert TVChartDatafeedMarksRequest is not None - assert TVChartDatafeedMarksResponse is not None - assert TVChartDatafeedResolveRequest is not None - assert TVChartDatafeedResolveResponse is not None - assert TVChartDatafeedSearchRequest is not None - assert TVChartDatafeedSearchResponse is not None - assert TVChartDatafeedServerTimeRequest is not None - assert TVChartDatafeedServerTimeResponse is not None - assert TVChartDatafeedSubscribeRequest is not None - assert TVChartDatafeedSymbolType is not None - assert TVChartDatafeedTimescaleMarksRequest is not None - assert TVChartDatafeedTimescaleMarksResponse is not None - assert TVChartDatafeedUnsubscribeRequest is not None - assert TVChartExchange is not None - assert TVChartLibrarySubsessionInfo is not None - assert TVChartMark is not None - assert TVChartSearchSymbolResultItem is not None - assert TVChartSymbolInfo is not None - assert TVChartSymbolInfoPriceSource is not None - assert TVChartTimescaleMark is not None - - def test_import_tvchart_state_mixin(self): - from pywry import TVChartStateMixin - - assert TVChartStateMixin is not None - - def test_import_pywry_tvchart_widget(self): - from pywry import PyWryTVChartWidget - - assert PyWryTVChartWidget is not None - - def test_import_show_tvchart(self): - from pywry import show_tvchart - - assert callable(show_tvchart) - - def test_import_build_tvchart_toolbars(self): - from pywry import build_tvchart_toolbars - - assert callable(build_tvchart_toolbars) - - def test_all_in_dunder_all(self): - import pywry - - all_names = pywry.__all__ - for name in [ - "TVChartBar", - "TVChartConfig", - "TVChartData", - "TVChartDatafeedBarUpdate", - "TVChartDatafeedConfigRequest", - "TVChartDatafeedConfigResponse", - "TVChartDatafeedConfiguration", - "TVChartDatafeedHistoryRequest", - "TVChartDatafeedHistoryResponse", - "TVChartDatafeedMarksRequest", - "TVChartDatafeedMarksResponse", - "TVChartDatafeedResolveRequest", - "TVChartDatafeedResolveResponse", - "TVChartDatafeedSearchRequest", - "TVChartDatafeedSearchResponse", - "TVChartDatafeedServerTimeRequest", - "TVChartDatafeedServerTimeResponse", - "TVChartDatafeedSubscribeRequest", - "TVChartDatafeedSymbolType", - "TVChartDatafeedTimescaleMarksRequest", - "TVChartDatafeedTimescaleMarksResponse", - "TVChartDatafeedUnsubscribeRequest", - "TVChartExchange", - "TVChartLibrarySubsessionInfo", - "TVChartMark", - "TVChartSearchSymbolResultItem", - "TVChartSymbolInfo", - "TVChartSymbolInfoPriceSource", - "TVChartTimescaleMark", - "TVChartStateMixin", - "PyWryTVChartWidget", - "show_tvchart", - "build_tvchart_toolbars", - ]: - assert name in all_names, f"{name} not in __all__" - - -# ============================================================================= -# show_tvchart wiring tests -# ============================================================================= - - -class TestShowTVChartSignature: - """Test that show_tvchart has the expected signature.""" - - def test_signature_params(self): - import inspect - - from pywry.inline import show_tvchart - - sig = inspect.signature(show_tvchart) - params = list(sig.parameters.keys()) - assert "data" in params - assert "callbacks" in params - assert "title" in params - assert "width" in params - assert "height" in params - assert "theme" in params - assert "chart_options" in params - assert "series_options" in params - assert "symbol_col" in params - assert "max_bars" in params - assert "toolbars" in params - assert "use_datafeed" in params - assert "symbol" in params - assert "resolution" in params - - def test_data_defaults_to_none(self): - import inspect - - from pywry.inline import show_tvchart - - sig = inspect.signature(show_tvchart) - assert sig.parameters["data"].default is None - - def test_use_datafeed_defaults_to_false(self): - import inspect - - from pywry.inline import show_tvchart - - sig = inspect.signature(show_tvchart) - assert sig.parameters["use_datafeed"].default is False - - def test_resolution_defaults_to_1d(self): - import inspect - - from pywry.inline import show_tvchart - - sig = inspect.signature(show_tvchart) - assert sig.parameters["resolution"].default == "1D" - - -# ============================================================================= -# Indicator catalog + compute + recompute coverage -# ============================================================================= - - -class TestTVChartIndicatorCatalog: - """Every indicator advertised by the catalog must have: - - * a compute function present in the bundled JS, - * an add-indicator branch that creates its series, and - * a recompute branch in ``_tvRecomputeIndicatorSeries`` so it refreshes - when underlying bars change (otherwise indicators silently freeze at - their initial snapshot when the datafeed replaces bars — exactly the - bug that made VWAP show 9.99 on a $270 stock). - """ - - @pytest.fixture - def js(self) -> str: - from pywry.assets import get_tvchart_defaults_js - - return get_tvchart_defaults_js() - - # ------------------------------------------------------------------ - # Catalog entries - # ------------------------------------------------------------------ - - EXPECTED_CATALOG_NAMES = ( - "Moving Average", - "Ichimoku Cloud", - "Bollinger Bands", - "Keltner Channels", - "ATR", - "Historical Volatility", - "Parabolic SAR", - "RSI", - "MACD", - "Stochastic", - "Williams %R", - "CCI", - "ADX", - "Aroon", - "VWAP", - "Volume SMA", - "Accumulation/Distribution", - "Volume Profile Fixed Range", - "Volume Profile Visible Range", - ) - - @pytest.mark.parametrize("name", EXPECTED_CATALOG_NAMES) - def test_catalog_contains_indicator(self, js: str, name: str) -> None: - cat_start = js.index("_INDICATOR_CATALOG = [") - cat_end = js.index("];", cat_start) - catalog_src = js[cat_start:cat_end] - assert f"name: '{name}'" in catalog_src, f"Indicator catalog missing entry for '{name}'" - - def test_volume_profile_entries_are_primitive(self, js: str) -> None: - cat_start = js.index("_INDICATOR_CATALOG = [") - cat_end = js.index("];", cat_start) - catalog_src = js[cat_start:cat_end] - for key in ("'volume-profile-fixed'", "'volume-profile-visible'"): - block = catalog_src[catalog_src.index(key) :] - first_close = block.index("}") - entry = block[:first_close] - assert "primitive: true" in entry, f"Expected VP entry {key} to have primitive: true" - - # ------------------------------------------------------------------ - # Compute functions - # ------------------------------------------------------------------ - - EXPECTED_COMPUTE_FNS = ( - "_computeSMA", - "_computeEMA", - "_computeWMA", - "_computeHMA", - "_computeVWMA", - "_computeRSI", - "_computeATR", - "_computeBollingerBands", - "_computeKeltnerChannels", - "_computeVWAP", - "_computeMACD", - "_computeStochastic", - "_computeAroon", - "_computeADX", - "_computeCCI", - "_computeWilliamsR", - "_computeAccumulationDistribution", - "_computeHistoricalVolatility", - "_computeIchimoku", - "_computeParabolicSAR", - ) - - @pytest.mark.parametrize("fn_name", EXPECTED_COMPUTE_FNS) - def test_compute_function_defined(self, js: str, fn_name: str) -> None: - assert f"function {fn_name}(" in js, f"Missing compute function {fn_name} in bundled JS" - - # ------------------------------------------------------------------ - # Add-indicator branches - # ------------------------------------------------------------------ - - ADD_BRANCHES = ( - ("name === 'VWAP'", "_computeVWAP"), - ("name === 'MACD'", "_computeMACD"), - ("name === 'Stochastic'", "_computeStochastic"), - ("name === 'Aroon'", "_computeAroon"), - ("name === 'ADX'", "_computeADX"), - ("name === 'CCI'", "_computeCCI"), - ("name === 'Williams %R'", "_computeWilliamsR"), - ("name === 'Accumulation/Distribution'", "_computeAccumulationDistribution"), - ("name === 'Historical Volatility'", "_computeHistoricalVolatility"), - ("name === 'Keltner Channels'", "_computeKeltnerChannels"), - ("name === 'Ichimoku Cloud'", "_computeIchimoku"), - ("name === 'Parabolic SAR'", "_computeParabolicSAR"), - ) - - @pytest.mark.parametrize("branch,fn", ADD_BRANCHES) - def test_add_branch_wires_compute(self, js: str, branch: str, fn: str) -> None: - assert branch in js, f"Missing add-indicator branch '{branch}' in 04-series.js" - # Narrow the search: compute call must appear after the branch and - # before the next `} else if (name ===` marker. - branch_idx = js.index(branch) - next_branch = js.find("} else if (name ===", branch_idx + 1) - if next_branch < 0: - next_branch = js.find("_tvAddIndicator fallthrough", branch_idx + 1) - segment = js[branch_idx : next_branch if next_branch > 0 else branch_idx + 2000] - assert fn in segment, ( - f"Branch for '{branch}' should call {fn}() but didn't within 2000 chars" - ) - - # ------------------------------------------------------------------ - # Recompute branches (THIS is the bug that caused VWAP=9.99) - # ------------------------------------------------------------------ - - @pytest.fixture - def recompute_body(self, js: str) -> str: - start = js.index("function _tvRecomputeIndicatorSeries(") - # Find matching close brace for the function - depth = 0 - i = js.index("{", start) - n = len(js) - while i < n: - ch = js[i] - if ch == "{": - depth += 1 - elif ch == "}": - depth -= 1 - if depth == 0: - return js[start : i + 1] - i += 1 - raise RuntimeError("Could not find end of _tvRecomputeIndicatorSeries") - - RECOMPUTE_BRANCHES = ( - ("info.name === 'VWAP'", "_computeVWAP"), - ("info.name === 'CCI'", "_computeCCI"), - ("info.name === 'Williams %R'", "_computeWilliamsR"), - ("info.name === 'Accumulation/Distribution'", "_computeAccumulationDistribution"), - ("info.name === 'Historical Volatility'", "_computeHistoricalVolatility"), - ("type === 'parabolic-sar'", "_computeParabolicSAR"), - ("type === 'macd'", "_computeMACD"), - ("type === 'stochastic'", "_computeStochastic"), - ("type === 'aroon'", "_computeAroon"), - ("type === 'adx'", "_computeADX"), - ("type === 'keltner-channels'", "_computeKeltnerChannels"), - ("type === 'ichimoku'", "_computeIchimoku"), - ) - - @pytest.mark.parametrize("branch,fn", RECOMPUTE_BRANCHES) - def test_recompute_branch_refreshes_series( - self, recompute_body: str, branch: str, fn: str - ) -> None: - assert branch in recompute_body, ( - f"_tvRecomputeIndicatorSeries missing branch for {branch!r}. " - "Without this branch, the indicator won't refresh when bars " - "change (e.g., via datafeed scrollback or interval switch) " - "and will stay frozen at its initial snapshot." - ) - idx = recompute_body.index(branch) - tail = recompute_body[idx : idx + 2500] - assert fn in tail, ( - f"Recompute branch {branch!r} found but never calls {fn}() " - "within the following 2500 chars — did the branch get broken?" - ) - - def test_recompute_branch_for_volume_profile(self, recompute_body: str) -> None: - """Visible-range volume profiles must recompute when the bar set - changes — otherwise scrolling into new data leaves their right-pinned - rows reflecting the old range.""" - assert "type === 'volume-profile-visible'" in recompute_body - assert "_tvRefreshVisibleVolumeProfiles" in recompute_body - - -# ============================================================================= -# Volume Profile compute contract -# ============================================================================= - - -class TestTVChartVolumeProfile: - """Tests for _tvComputeVolumeProfile — the pure function behind VPVR.""" - - @pytest.fixture - def js(self) -> str: - from pywry.assets import get_tvchart_defaults_js - - return get_tvchart_defaults_js() - - def test_vp_compute_function_signature(self, js: str) -> None: - assert "function _tvComputeVolumeProfile(bars, fromIdx, toIdx, opts)" in js - - def test_vp_result_returns_profile_and_metadata(self, js: str) -> None: - fn_start = js.index("function _tvComputeVolumeProfile(") - fn_end = js.index("\nfunction ", fn_start + 1) - body = js[fn_start:fn_end] - for key in ("profile", "minPrice", "maxPrice", "step", "totalVolume"): - assert key in body, f"VP compute result missing expected field '{key}'" - - def test_vp_splits_up_down_volume(self, js: str) -> None: - fn_start = js.index("function _tvComputeVolumeProfile(") - fn_end = js.index("\nfunction ", fn_start + 1) - body = js[fn_start:fn_end] - # Up/down split is what differentiates VPVR from a flat histogram. - assert "upVol" in body and "downVol" in body, ( - "VP compute must split each row into up vs down volume" - ) - - def test_vp_exposes_poc_value_area_helper(self, js: str) -> None: - """A separate helper derives POC and Value Area from the computed profile.""" - assert "function _tvComputePOCAndValueArea(" in js - fn_start = js.index("function _tvComputePOCAndValueArea(") - fn_end = js.index("\nfunction ", fn_start + 1) - body = js[fn_start:fn_end] - for key in ("pocIdx", "vaLowIdx", "vaHighIdx"): - assert key in body, f"POC/VA helper must expose '{key}' so renderer can draw lines" - - def test_vp_refresh_visible_exposed(self, js: str) -> None: - """Visible-range refresh must exist for the recompute path to call it.""" - assert "function _tvRefreshVisibleVolumeProfiles(chartId)" in js - - -# ============================================================================= -# Legend volume removal actually destroys the series + pane -# ============================================================================= - - -class TestTVChartLegendVolumeRemoval: - """Removing volume from the legend must actually remove it from the chart - (issue: previously, clicking Remove only set a legend dataset flag but - left the histogram series and its pane on the chart).""" - - @pytest.fixture - def js(self) -> str: - from pywry.assets import get_tvchart_defaults_js - - return get_tvchart_defaults_js() - - def _fn_or_nested(self, js: str, name: str) -> str: - """Extract a function body — works for nested ``function X()`` too.""" - idx = js.index(f"function {name}(") - depth = 0 - i = js.index("{", idx) - n = len(js) - while i < n: - ch = js[i] - if ch == "{": - depth += 1 - elif ch == "}": - depth -= 1 - if depth == 0: - return js[idx : i + 1] - i += 1 - raise RuntimeError(f"Could not find end of {name}") - - def test_disable_volume_removes_series(self, js: str) -> None: - body = self._fn_or_nested(js, "_legendDisableVolume") - assert "entry.chart.removeSeries(volSeries)" in body, ( - "Remove-volume must actually call chart.removeSeries" - ) - assert "delete entry.volumeMap.main" in body, "Remove-volume must clear the volumeMap entry" - - def test_disable_volume_removes_pane(self, js: str) -> None: - body = self._fn_or_nested(js, "_legendDisableVolume") - assert "chart.removePane(removedPane)" in body, ( - "Remove-volume must collapse the now-empty pane, not leave dead space" - ) - - def test_disable_volume_reindexes_panes(self, js: str) -> None: - body = self._fn_or_nested(js, "_legendDisableVolume") - # When pane N is removed, LWC reindexes panes > N down by 1. We must - # mirror that for our bookkeeping on _activeIndicators and _volumePaneBySeries. - assert ".paneIndex -= 1" in body - assert "_volumePaneBySeries" in body - - def test_enable_volume_rebuilds_series(self, js: str) -> None: - body = self._fn_or_nested(js, "_legendEnableVolume") - assert "_tvAddSeriesCompat(entry.chart, 'Histogram'" in body, ( - "Restore-volume must rebuild the histogram series via the same " - "path used for initial creation" - ) - assert "_tvExtractVolumeFromBars" in body, ( - "Restore-volume must re-extract volume from the stored raw bars" - ) - - -# ============================================================================= -# Theme CSS variables — every new VP / indicator color var is defined -# ============================================================================= - - -class TestTVChartThemeVariables: - """The tvchart.css stylesheet must define every CSS variable that the - frontend JS consumes, in both dark and light themes (otherwise colors - silently fall back to whatever the browser decides).""" - - @pytest.fixture - def css(self) -> str: - from pathlib import Path - - import pywry - - return (Path(pywry.__file__).parent / "frontend" / "style" / "tvchart.css").read_text( - encoding="utf-8" - ) - - VP_VARS = ( - "--pywry-tvchart-vp-up", - "--pywry-tvchart-vp-down", - "--pywry-tvchart-vp-va-up", - "--pywry-tvchart-vp-va-down", - "--pywry-tvchart-vp-poc", - ) - - INDICATOR_PALETTE_VARS = ( - "--pywry-tvchart-ind-primary", - "--pywry-tvchart-ind-secondary", - "--pywry-tvchart-ind-tertiary", - "--pywry-tvchart-ind-positive", - "--pywry-tvchart-ind-negative", - "--pywry-tvchart-ind-positive-dim", - "--pywry-tvchart-ind-negative-dim", - ) - - @pytest.mark.parametrize("var", VP_VARS + INDICATOR_PALETTE_VARS) - def test_var_defined_at_least_twice(self, css: str, var: str) -> None: - """Each var must appear in both the dark (root) and light theme blocks.""" - count = css.count(var + ":") - assert count >= 2, ( - f"CSS var {var} defined only {count} time(s); expected at least 2 " - "(one for dark theme, one for light)." - ) - - -# ============================================================================= -# MCP tool definition tests -# ============================================================================= - - -class TestMCPToolDefinition: - """Verify show_tvchart MCP tool is registered.""" - - def test_tool_schema_exists(self): - from pywry.mcp.tools import get_tools - - names = [t.name for t in get_tools()] - assert "show_tvchart" in names - - def test_tool_schema_has_data_json(self): - from pywry.mcp.tools import get_tools - - tool = next(t for t in get_tools() if t.name == "show_tvchart") - props = tool.inputSchema["properties"] - assert "data_json" in props - - def test_handler_registered(self): - from pywry.mcp.handlers import _HANDLERS - - assert "show_tvchart" in _HANDLERS - - -# --------------------------------------------------------------------------- -# Alternative chart factories: createOptionsChart + createYieldCurveChart -# --------------------------------------------------------------------------- - - -class TestTVChartSpecialtyChartKinds: - """Contract checks for the two non-temporal LWC chart factories. - - Lightweight Charts 5.x exposes three factories: - * createChart — time X axis (default) - * createOptionsChart — numeric price / strike X axis - * createYieldCurveChart — tenor-in-months X axis - - PyWry routes these via ``payload.chartKind`` in - ``PYWRY_TVCHART_CREATE``. These tests lock down both the dispatch - logic AND the option builders so future refactors can't silently - break either branch. - """ - - @pytest.fixture - def tvchart_defaults_js(self) -> str: - from pywry.assets import get_tvchart_defaults_js - - return get_tvchart_defaults_js() - - def test_bundle_ships_all_three_builders(self, tvchart_defaults_js: str): - assert "function _tvBuildChartOptions(" in tvchart_defaults_js - assert "function _tvBuildPriceChartOptions(" in tvchart_defaults_js - assert "function _tvBuildYieldCurveChartOptions(" in tvchart_defaults_js - - def test_price_builder_inherits_base_defaults(self, tvchart_defaults_js: str): - src = tvchart_defaults_js - start = src.index("function _tvBuildPriceChartOptions(") - body = TestTVChartFrontendStateContracts._extract_braced(src, start) - assert "_tvBuildChartOptions(null, theme)" in body, ( - "price chart options must inherit the base PyWry defaults so " - "palette / interaction / scales stay consistent across factories" - ) - - def test_yield_curve_builder_seeds_yield_curve_options(self, tvchart_defaults_js: str): - src = tvchart_defaults_js - start = src.index("function _tvBuildYieldCurveChartOptions(") - body = TestTVChartFrontendStateContracts._extract_braced(src, start) - assert "_tvBuildChartOptions(null, theme)" in body - assert "yieldCurve" in body - assert "baseResolution" in body - assert "minimumTimeRange" in body - assert "startTimeRange" in body - - def test_yield_curve_builder_ignores_whitespace_indices(self, tvchart_defaults_js: str): - """The crosshair must snap to real tenors — a yield curve has - irregular whitespace between 2Y and 5Y, 5Y and 10Y, etc.""" - src = tvchart_defaults_js - start = src.index("function _tvBuildYieldCurveChartOptions(") - body = TestTVChartFrontendStateContracts._extract_braced(src, start) - assert "ignoreWhitespaceIndices = true" in body - - def test_create_dispatches_to_price_factory(self, tvchart_defaults_js: str): - body = TestTVChartFrontendStateContracts()._create_body(tvchart_defaults_js) - assert "LightweightCharts.createOptionsChart(container, chartOptions)" in body - assert "chartKind === 'price'" in body - - def test_create_dispatches_to_yield_curve_factory(self, tvchart_defaults_js: str): - body = TestTVChartFrontendStateContracts()._create_body(tvchart_defaults_js) - assert "LightweightCharts.createYieldCurveChart(container, chartOptions)" in body - assert "yield-curve" in body - - def test_create_default_falls_back_to_create_chart(self, tvchart_defaults_js: str): - body = TestTVChartFrontendStateContracts()._create_body(tvchart_defaults_js) - assert "LightweightCharts.createChart(container, chartOptions)" in body - - def test_volume_auto_enable_gated_on_default_chart_kind(self, tvchart_defaults_js: str): - """Auto-volume on price / yield-curve charts would histogram - by strike / tenor which is meaningless — gate it off.""" - body = TestTVChartFrontendStateContracts()._create_body(tvchart_defaults_js) - assert "enableVolume !== false && chartKind === 'default'" in body - - def test_time_range_tabs_gated_on_default_chart_kind(self, tvchart_defaults_js: str): - """'1D / 5D / 1Y / ...' tabs only make sense for time-axis - charts. Skip the lookup on specialty kinds.""" - body = TestTVChartFrontendStateContracts()._create_body(tvchart_defaults_js) - # Guard is an inline `if (chartKind === 'default')` ahead of the - # `.pywry-tab-active[data-target-interval]` query. - idx_guard = body.find("chartKind === 'default'") - idx_tab_query = body.find(".pywry-tab-active[data-target-interval]") - assert idx_guard != -1 and idx_tab_query != -1 - assert idx_guard < idx_tab_query, ( - "the chartKind guard must appear BEFORE the time-range tab " - "lookup so non-default charts skip the whole block" - ) - - -class TestTVChartChartKindConfig: - """Python typed surface for the chartKind selector. - - Locks in the TVChartConfig literal + the to_payload shape that the - frontend consumes. - """ - - def test_config_default_is_time_axis(self): - from pywry.tvchart.config import TVChartConfig - - cfg = TVChartConfig() - assert cfg.chart_kind == "default" - assert cfg.yield_curve is None - - def test_config_accepts_price_kind(self): - from pywry.tvchart.config import TVChartConfig - - cfg = TVChartConfig(chart_kind="price") - assert cfg.chart_kind == "price" - - def test_config_accepts_yield_curve_kind(self): - from pywry.tvchart.config import TVChartConfig - - cfg = TVChartConfig(chart_kind="yield-curve") - assert cfg.chart_kind == "yield-curve" - - def test_config_rejects_unknown_kind(self): - import pydantic - - from pywry.tvchart.config import TVChartConfig - - with pytest.raises(pydantic.ValidationError): - TVChartConfig(chart_kind="candlestick") # type: ignore - - def test_to_payload_exposes_chart_kind_alongside_options(self): - from pywry.tvchart.config import TVChartConfig - - cfg = TVChartConfig(chart_kind="price") - payload = cfg.to_payload() - assert payload["chartKind"] == "price" - assert isinstance(payload["chartOptions"], dict) - - def test_to_payload_forwards_yield_curve_options(self): - from pywry.tvchart.config import TVChartConfig - - cfg = TVChartConfig( - chart_kind="yield-curve", - yield_curve={ - "baseResolution": 1, - "minimumTimeRange": 360, - "startTimeRange": 0, - }, - ) - payload = cfg.to_payload() - assert payload["chartKind"] == "yield-curve" - assert payload["chartOptions"]["yieldCurve"]["minimumTimeRange"] == 360 - - def test_to_chart_options_skips_yield_curve_when_unset(self): - from pywry.tvchart.config import TVChartConfig - - cfg = TVChartConfig(chart_kind="yield-curve") - opts = cfg.to_chart_options() - assert "yieldCurve" not in opts, ( - "yield_curve is optional — don't ship an empty block that the " - "frontend would treat as a wipe of the LWC defaults" - ) - - -class TestTVChartSpecialtyInlinePayload: - """The inline (notebook) path must carry chart_kind into the JSON - payload that gets dumped into the PyWryTVChartWidget's chart_config - traitlet — that's the only channel the frontend reads.""" - - def test_inline_payload_carries_chart_kind(self): - import inspect - - from pywry import inline as pywry_inline - - src = inspect.getsource(pywry_inline.show_tvchart) - assert '"chartKind": chart_kind' in src, ( - "chart_kind must land in the JSON config_payload so the " - "frontend can route to createOptionsChart / " - "createYieldCurveChart" - ) - - def test_inline_show_tvchart_accepts_chart_kind(self): - import inspect - - from pywry import inline as pywry_inline - - sig = inspect.signature(pywry_inline.show_tvchart) - assert "chart_kind" in sig.parameters - assert sig.parameters["chart_kind"].default == "default" - - def test_app_show_tvchart_accepts_chart_kind(self): - import inspect - - from pywry.app import PyWry - - sig = inspect.signature(PyWry.show_tvchart) - assert "chart_kind" in sig.parameters - assert "yield_curve" in sig.parameters - assert sig.parameters["chart_kind"].default == "default" - - def test_specialty_demo_cells_in_notebook(self): - """The TVChart demo notebook must include runnable cells for - both alternative chart kinds — keeps the documented example in - sync with the public chart_kind / yield_curve API surface.""" - import ast - import json - - from pathlib import Path - - nb_path = Path(__file__).resolve().parent.parent / "examples" / "pywry_demo_tvchart.ipynb" - if not nb_path.exists(): - pytest.skip("demo notebook not bundled in this source tree") - nb = json.loads(nb_path.read_text(encoding="utf-8")) - code_cells = [ - "".join(c.get("source", [])) - for c in nb.get("cells", []) - if c.get("cell_type") == "code" - ] - assert any( - 'chart_kind="yield-curve"' in src and "yield_curve" in src for src in code_cells - ), "notebook missing a yield-curve chart cell" - assert any('chart_kind="price"' in src for src in code_cells), ( - "notebook missing a price-axis (options payoff) chart cell" - ) - # Every code cell must still parse as valid Python so stale - # snippets break this test loudly instead of silently rotting. - for src in code_cells: - ast.parse(src) diff --git a/pywry/tests/test_tvchart_e2e.py b/pywry/tests/test_tvchart_e2e.py index 9a78d21..d494a5e 100644 --- a/pywry/tests/test_tvchart_e2e.py +++ b/pywry/tests/test_tvchart_e2e.py @@ -1457,19 +1457,13 @@ def test_05_server_remove_emits_event(self, server_storage_chart: dict[str, Any] # Inline Mode (synthetic data, no UDF) # ============================================================================ -try: - from pywry.inline import HAS_FASTAPI -except ImportError: - HAS_FASTAPI = False - - def _http_get(url: str, timeout: float = 5.0) -> str: req = urllib.request.Request(url) with urllib.request.urlopen(req, timeout=timeout) as resp: return resp.read().decode("utf-8") -@pytest.mark.skipif(not HAS_FASTAPI, reason="FastAPI not installed") + class TestTVChartInline: """Inline rendering path.""" @@ -1531,7 +1525,7 @@ def _wait_for_port_release(port: int, timeout: float = 5.0) -> bool: return False -@pytest.mark.skipif(not HAS_FASTAPI, reason="FastAPI not installed") + class TestTVChartBrowser: """Browser mode rendering path.""" diff --git a/pywry/tests/test_types.py b/pywry/tests/test_types.py index b9daa56..56da21d 100644 --- a/pywry/tests/test_types.py +++ b/pywry/tests/test_types.py @@ -9,21 +9,29 @@ import pytest from pywry.types import ( + CheckMenuItemConfig, Cookie, CursorIcon, Effect, Effects, EffectState, + IconMenuItemConfig, LogicalPosition, LogicalSize, + MenuConfig, + MenuItemConfig, Monitor, PhysicalPosition, PhysicalSize, + PredefinedMenuItemConfig, + PredefinedMenuItemKind, ProgressBarState, ProgressBarStatus, SameSite, + SubmenuConfig, Theme, TitleBarStyle, + TrayIconConfig, UserAttentionType, serialize_effects, serialize_position, @@ -474,3 +482,205 @@ def test_position_hashable(self) -> None: PhysicalPosition(1, 1), } assert len(positions) == 2 + + +def _click(*_args, **_kwargs): + return None + + +class TestMenuItemConfig: + def test_handler_required(self): + with pytest.raises(TypeError, match="requires a handler"): + MenuItemConfig(id="x", text="X", handler=None) + + def test_to_dict_basic(self): + item = MenuItemConfig(id="x", text="X", handler=_click) + d = item.to_dict() + assert d["kind"] == "item" + assert d["id"] == "x" + assert d["text"] == "X" + assert d["enabled"] is True + assert "accelerator" not in d + + def test_to_dict_with_accelerator(self): + item = MenuItemConfig(id="x", text="X", handler=_click, accelerator="Ctrl+S") + d = item.to_dict() + assert d["accelerator"] == "Ctrl+S" + + +class TestCheckMenuItemConfig: + def test_handler_required(self): + with pytest.raises(TypeError, match="requires a handler"): + CheckMenuItemConfig(id="x", text="X", handler=None) + + def test_to_dict_basic(self): + item = CheckMenuItemConfig(id="x", text="X", handler=_click, checked=True) + d = item.to_dict() + assert d["kind"] == "check" + assert d["checked"] is True + + def test_to_dict_accelerator(self): + item = CheckMenuItemConfig(id="x", text="X", handler=_click, accelerator="Ctrl+T") + d = item.to_dict() + assert d["accelerator"] == "Ctrl+T" + + +class TestIconMenuItemConfig: + def test_handler_required(self): + with pytest.raises(TypeError, match="requires a handler"): + IconMenuItemConfig(id="x", text="X", handler=None) + + def test_to_dict_with_icon_bytes(self): + item = IconMenuItemConfig(id="x", text="X", handler=_click, icon=b"\x00\x01\x02\x03") + d = item.to_dict() + assert d["kind"] == "icon" + assert "icon" in d + assert d["icon_width"] == 16 + assert d["icon_height"] == 16 + + def test_to_dict_native_icon(self): + item = IconMenuItemConfig(id="x", text="X", handler=_click, native_icon="Add") + d = item.to_dict() + assert d["native_icon"] == "Add" + + def test_to_dict_accelerator(self): + item = IconMenuItemConfig(id="x", text="X", handler=_click, accelerator="Ctrl+I") + d = item.to_dict() + assert d["accelerator"] == "Ctrl+I" + + def test_to_dict_no_icon(self): + item = IconMenuItemConfig(id="x", text="X", handler=_click) + d = item.to_dict() + assert "icon" not in d + assert "icon_width" not in d + + +class TestPredefinedMenuItemConfig: + def test_to_dict_basic(self): + item = PredefinedMenuItemConfig(kind_name=PredefinedMenuItemKind.SEPARATOR) + d = item.to_dict() + assert d["kind"] == "predefined" + assert d["kind_name"] == "separator" + assert "text" not in d + + def test_to_dict_with_text(self): + item = PredefinedMenuItemConfig(kind_name=PredefinedMenuItemKind.QUIT, text="Exit") + d = item.to_dict() + assert d["text"] == "Exit" + + +class TestSubmenuConfig: + def test_to_dict_empty(self): + s = SubmenuConfig(id="s", text="Sub") + d = s.to_dict() + assert d["kind"] == "submenu" + assert d["id"] == "s" + assert "items" not in d + + def test_to_dict_with_items(self): + s = SubmenuConfig( + id="s", + text="Sub", + items=[ + MenuItemConfig(id="i1", text="One", handler=_click), + MenuItemConfig(id="i2", text="Two", handler=_click), + ], + ) + d = s.to_dict() + assert "items" in d + assert len(d["items"]) == 2 + + def test_collect_handlers_includes_descendants(self): + s = SubmenuConfig( + id="s", + text="Sub", + items=[ + MenuItemConfig(id="i1", text="A", handler=_click), + CheckMenuItemConfig(id="i2", text="B", handler=_click), + ], + ) + handlers = s.collect_handlers() + assert "i1" in handlers + assert "i2" in handlers + + +class TestMenuConfig: + def test_to_dict(self): + cfg = MenuConfig( + id="main", + items=[MenuItemConfig(id="x", text="X", handler=_click)], + ) + d = cfg.to_dict() + assert d["id"] == "main" + assert isinstance(d["items"], list) + + def test_collect_handlers(self): + cfg = MenuConfig( + id="main", + items=[ + MenuItemConfig(id="x", text="X", handler=_click), + SubmenuConfig( + id="sub", + text="Sub", + items=[MenuItemConfig(id="nested", text="N", handler=_click)], + ), + ], + ) + handlers = cfg.collect_handlers() + assert "x" in handlers + assert "nested" in handlers + + def test_round_trip_from_dict(self): + original = MenuConfig( + id="main", + items=[ + MenuItemConfig(id="x", text="X", handler=_click, accelerator="Ctrl+X"), + CheckMenuItemConfig(id="c", text="C", handler=_click, checked=True), + IconMenuItemConfig(id="i", text="I", handler=_click, icon=b"\x00\xff"), + PredefinedMenuItemConfig(kind_name=PredefinedMenuItemKind.SEPARATOR), + SubmenuConfig( + id="sub", + text="Sub", + items=[MenuItemConfig(id="n", text="N", handler=_click)], + ), + ], + ) + data = original.to_dict() + restored = MenuConfig.from_dict(data) + assert restored.id == "main" + assert len(restored.items) == 5 + + def test_from_dict_unknown_kind_raises(self): + with pytest.raises(ValueError, match="Unknown menu item kind"): + MenuConfig.from_dict({"id": "m", "items": [{"kind": "garbage", "id": "x", "text": "X"}]}) + + +class TestTrayIconConfig: + def test_minimal_to_dict(self): + cfg = TrayIconConfig(id="t1") + d = cfg.to_dict() + assert d["id"] == "t1" + assert d["menu_on_left_click"] is True + assert "tooltip" not in d + assert "title" not in d + assert "icon" not in d + assert "menu" not in d + + def test_full_to_dict(self): + menu = MenuConfig(id="m", items=[MenuItemConfig(id="x", text="X", handler=_click)]) + cfg = TrayIconConfig( + id="t1", + tooltip="Hover", + title="MyApp", + icon=b"\x00\xff\x00\xff", + menu=menu, + menu_on_left_click=False, + ) + d = cfg.to_dict() + assert d["tooltip"] == "Hover" + assert d["title"] == "MyApp" + assert "icon" in d + assert d["icon_width"] == 32 + assert d["icon_height"] == 32 + assert d["menu_on_left_click"] is False + assert d["menu"]["id"] == "m" diff --git a/pywry/tests/test_udf_adapter.py b/pywry/tests/test_udf_adapter.py deleted file mode 100644 index d5dad36..0000000 --- a/pywry/tests/test_udf_adapter.py +++ /dev/null @@ -1,562 +0,0 @@ -"""Unit tests for pywry.tvchart.udf (UDF adapter).""" - -from __future__ import annotations - -from typing import Any - -import httpx -import pytest - -from pywry.tvchart.udf import ( - QuoteData, - UDFAdapter, - from_udf_resolution, - parse_udf_columns, - to_udf_resolution, -) - - -# --------------------------------------------------------------------------- -# Resolution mapping -# --------------------------------------------------------------------------- - - -class TestResolutionMapping: - """Tests for to_udf_resolution / from_udf_resolution.""" - - @pytest.mark.parametrize( - ("canonical", "udf"), - [ - ("1m", "1"), - ("5m", "5"), - ("15m", "15"), - ("30m", "30"), - ("1h", "60"), - ("2h", "120"), - ("4h", "240"), - ("1d", "D"), - ("1D", "D"), - ("1w", "W"), - ("1W", "W"), - ("1M", "M"), - ("3M", "3M"), - ], - ) - def test_canonical_to_udf(self, canonical: str, udf: str) -> None: - assert to_udf_resolution(canonical) == udf - - @pytest.mark.parametrize( - ("udf", "canonical"), - [ - ("1", "1m"), - ("5", "5m"), - ("60", "1h"), - ("D", "1d"), - ("1D", "1d"), - ("W", "1w"), - ("M", "1M"), - ("1M", "1M"), - ], - ) - def test_udf_to_canonical(self, udf: str, canonical: str) -> None: - assert from_udf_resolution(udf) == canonical - - def test_passthrough_unknown(self) -> None: - assert to_udf_resolution("UNKNOWN") == "UNKNOWN" - assert from_udf_resolution("UNKNOWN") == "UNKNOWN" - - -# --------------------------------------------------------------------------- -# UDF columnar parsing -# --------------------------------------------------------------------------- - - -class TestParseUDFColumns: - """Tests for parse_udf_columns.""" - - def test_basic_table(self) -> None: - data = { - "t": [100, 200, 300], - "c": [10.0, 20.0, 30.0], - "v": [1000, 2000, 3000], - } - rows = parse_udf_columns(data) - assert len(rows) == 3 - assert rows[0] == {"t": 100, "c": 10.0, "v": 1000} - assert rows[2] == {"t": 300, "c": 30.0, "v": 3000} - - def test_scalar_broadcast(self) -> None: - data = { - "symbol": ["AAPL", "MSFT"], - "exchange": "NASDAQ", - "pricescale": 100, - } - rows = parse_udf_columns(data) - assert len(rows) == 2 - assert rows[0] == {"symbol": "AAPL", "exchange": "NASDAQ", "pricescale": 100} - assert rows[1] == {"symbol": "MSFT", "exchange": "NASDAQ", "pricescale": 100} - - def test_empty_data(self) -> None: - assert parse_udf_columns({}) == [] - assert parse_udf_columns({"scalar": 42}) == [] - - def test_explicit_count(self) -> None: - data = {"val": 99} - rows = parse_udf_columns(data, count=3) - assert len(rows) == 3 - assert all(r["val"] == 99 for r in rows) - - def test_mixed_scalar_and_list(self) -> None: - data = { - "id": [1, 2], - "time": [1000, 2000], - "color": "red", - "label": ["A", "B"], - } - rows = parse_udf_columns(data) - assert len(rows) == 2 - assert rows[0] == {"id": 1, "time": 1000, "color": "red", "label": "A"} - assert rows[1] == {"id": 2, "time": 2000, "color": "red", "label": "B"} - - -# --------------------------------------------------------------------------- -# QuoteData -# --------------------------------------------------------------------------- - - -class TestQuoteData: - """Tests for the QuoteData value object.""" - - def test_basic_fields(self) -> None: - q = QuoteData( - n="NYSE:AA", - s="ok", - v={ - "ch": 0.16, - "chp": 0.98, - "short_name": "AA", - "exchange": "NYSE", - "description": "Alcoa Inc.", - "lp": 16.57, - "ask": 16.58, - "bid": 16.57, - "open_price": 16.25, - "high_price": 16.60, - "low_price": 16.25, - "prev_close_price": 16.41, - "volume": 4029041, - }, - ) - assert q.symbol == "NYSE:AA" - assert q.status == "ok" - assert q.last_price == 16.57 - assert q.change == 0.16 - assert q.change_percent == 0.98 - assert q.short_name == "AA" - assert q.volume == 4029041 - - def test_format_ticker_html_positive(self) -> None: - q = QuoteData( - n="AAPL", - s="ok", - v={"short_name": "AAPL", "lp": 186.25, "ch": 1.50, "chp": 0.81}, - ) - html = q.format_ticker_html() - assert "AAPL" in html - assert "186.25" in html - assert "+1.50" in html - assert "+0.81%" in html - assert "pywry-success" in html - - def test_format_ticker_html_negative(self) -> None: - q = QuoteData( - n="MSFT", - s="ok", - v={"short_name": "MSFT", "lp": 415.00, "ch": -2.50, "chp": -0.60}, - ) - html = q.format_ticker_html() - assert "MSFT" in html - assert "-2.50" in html - assert "pywry-error" in html - - def test_format_ticker_html_no_change(self) -> None: - q = QuoteData(n="X", s="ok", v={"short_name": "X", "lp": 10.0}) - html = q.format_ticker_html(show_change=True) - assert "10.00" in html - # No change data → just price, no color spans - assert "span" not in html - - def test_empty_quote(self) -> None: - q = QuoteData(n="", s="error", v={}, errmsg="not found") - assert q.status == "error" - assert q.error == "not found" - assert q.last_price is None - - -# --------------------------------------------------------------------------- -# UDFAdapter — unit tests with mocked HTTP -# --------------------------------------------------------------------------- - - -class TestUDFAdapterBarParsing: - """Test the bar-parsing logic in get_bars.""" - - @pytest.fixture() - def adapter(self) -> UDFAdapter: - return UDFAdapter("https://example.com") - - @pytest.mark.asyncio() - async def test_parse_history_ok( - self, adapter: UDFAdapter, monkeypatch: pytest.MonkeyPatch - ) -> None: - raw_response = { - "s": "ok", - "t": [1386493512, 1386493572], - "c": [42.1, 43.4], - "o": [41.0, 42.9], - "h": [43.0, 44.1], - "l": [40.4, 42.1], - "v": [12000, 18500], - } - - async def mock_get(path: str, params: dict | None = None) -> Any: - return _MockResponse(200, raw_response) - - monkeypatch.setattr(adapter._client, "get", mock_get) - - result = await adapter.get_bars("AAPL", "D", 1386493512, 1386493999) - - assert result["status"] == "ok" - assert len(result["bars"]) == 2 - bar0 = result["bars"][0] - assert bar0["time"] == 1386493512 # Unix seconds (not ms) - assert bar0["open"] == 41.0 - assert bar0["close"] == 42.1 - assert bar0["volume"] == 12000 - - @pytest.mark.asyncio() - async def test_parse_history_no_data( - self, adapter: UDFAdapter, monkeypatch: pytest.MonkeyPatch - ) -> None: - raw_response = {"s": "no_data", "nextTime": 1428001140000} - - async def mock_get(path: str, params: dict | None = None) -> Any: - return _MockResponse(200, raw_response) - - monkeypatch.setattr(adapter._client, "get", mock_get) - - result = await adapter.get_bars("AAPL", "1", 100, 200) - assert result["status"] == "no_data" - assert result["no_data"] is True - assert result["next_time"] == 1428001140000 - - @pytest.mark.asyncio() - async def test_parse_history_ok_with_no_data_flag( - self, - adapter: UDFAdapter, - monkeypatch: pytest.MonkeyPatch, - ) -> None: - """UDF servers can return bars AND noData=true to signal oldest data.""" - raw_response = { - "s": "ok", - "t": [1386493512], - "c": [42.1], - "o": [41.0], - "h": [43.0], - "l": [40.4], - "v": [12000], - "noData": True, - } - - async def mock_get(path: str, params: dict | None = None) -> Any: - return _MockResponse(200, raw_response) - - monkeypatch.setattr(adapter._client, "get", mock_get) - - result = await adapter.get_bars("AAPL", "D", 1386493512, 1386493999) - - assert result["status"] == "ok" - assert len(result["bars"]) == 1 - assert result["no_data"] is True # Server's noData flag preserved - assert result["bars"][0]["time"] == 1386493512 - - @pytest.mark.asyncio() - async def test_parse_history_error( - self, adapter: UDFAdapter, monkeypatch: pytest.MonkeyPatch - ) -> None: - raw_response = {"s": "error", "errmsg": "Invalid symbol"} - - async def mock_get(path: str, params: dict | None = None) -> Any: - return _MockResponse(200, raw_response) - - monkeypatch.setattr(adapter._client, "get", mock_get) - - result = await adapter.get_bars("INVALID", "D", 100, 200) - assert result["status"] == "error" - assert result["error"] == "Invalid symbol" - - -class TestUDFAdapterConfig: - """Test config parsing from /config.""" - - @pytest.fixture() - def adapter(self) -> UDFAdapter: - return UDFAdapter("https://example.com") - - @pytest.mark.asyncio() - async def test_parse_config(self, adapter: UDFAdapter, monkeypatch: pytest.MonkeyPatch) -> None: - raw_config = { - "supports_search": True, - "supports_group_request": False, - "supports_marks": True, - "supports_timescale_marks": True, - "supports_time": True, - "exchanges": [ - {"value": "", "name": "All Exchanges", "desc": ""}, - {"value": "XETRA", "name": "XETRA", "desc": "XETRA"}, - ], - "symbols_types": [{"name": "All types", "value": ""}], - "supported_resolutions": ["D", "2D", "3D", "W", "3W", "M", "6M"], - } - - async def mock_get(path: str, params: dict | None = None) -> Any: - return _MockResponse(200, raw_config) - - monkeypatch.setattr(adapter._client, "get", mock_get) - - config = await adapter._fetch_config() - - assert adapter._supports_search is True - assert adapter._supports_marks is True - assert adapter._supports_timescale_marks is True - assert adapter._supports_time is True - assert config["supported_resolutions"] == ["D", "2D", "3D", "W", "3W", "M", "6M"] - assert len(config["exchanges"]) == 2 - # supports_search and supports_group_request must be forwarded to frontend - assert config["supports_search"] is True - - -class TestUDFAdapterResolve: - """Test symbol resolution and key mapping.""" - - @pytest.fixture() - def adapter(self) -> UDFAdapter: - return UDFAdapter("https://example.com") - - @pytest.mark.asyncio() - async def test_resolve_maps_hyphen_keys( - self, adapter: UDFAdapter, monkeypatch: pytest.MonkeyPatch - ) -> None: - raw_symbol = { - "name": "AAPL", - "description": "Apple Inc", - "exchange": "NASDAQ", - "type": "stock", - "session-regular": "0930-1600", - "has-intraday": True, - "has-daily": True, - "has-weekly-and-monthly": True, - "supported-resolutions": ["1", "5", "15", "30", "60", "D", "W", "M"], - "minmovement": 1, - "pricescale": 100, - "timezone": "America/New_York", - } - - async def mock_get(path: str, params: dict | None = None) -> Any: - return _MockResponse(200, raw_symbol) - - monkeypatch.setattr(adapter._client, "get", mock_get) - - info = await adapter.resolve_symbol("AAPL") - - assert info["name"] == "AAPL" - assert info["session"] == "0930-1600" - assert info["has_intraday"] is True - assert info["has_daily"] is True - assert info["has_weekly_and_monthly"] is True - assert info["supported_resolutions"] == ["1", "5", "15", "30", "60", "D", "W", "M"] - assert info["minmov"] == 1 - - -class TestUDFAdapterSearch: - """Test symbol search.""" - - @pytest.fixture() - def adapter(self) -> UDFAdapter: - return UDFAdapter("https://example.com") - - @pytest.mark.asyncio() - async def test_search_returns_items( - self, adapter: UDFAdapter, monkeypatch: pytest.MonkeyPatch - ) -> None: - raw_results = [ - { - "symbol": "AAPL", - "full_name": "NASDAQ:AAPL", - "description": "Apple Inc", - "exchange": "NASDAQ", - "type": "stock", - }, - { - "symbol": "AA", - "full_name": "NYSE:AA", - "description": "Alcoa", - "exchange": "NYSE", - "type": "stock", - }, - ] - - async def mock_get(path: str = "", params: dict | None = None) -> Any: - return _MockResponse(200, raw_results) - - monkeypatch.setattr(adapter._client, "get", mock_get) - - items = await adapter.search_symbols("AA", limit=10) - assert len(items) == 2 - assert items[0]["symbol"] == "AAPL" - - -class TestUDFAdapterMarks: - """Test marks parsing.""" - - @pytest.fixture() - def adapter(self) -> UDFAdapter: - return UDFAdapter("https://example.com") - - @pytest.mark.asyncio() - async def test_marks_columnar( - self, adapter: UDFAdapter, monkeypatch: pytest.MonkeyPatch - ) -> None: - raw_marks = { - "id": [1, 2], - "time": [1000, 2000], - "color": "red", - "text": ["Mark 1", "Mark 2"], - "label": ["A", "B"], - "labelFontColor": "white", - "minSize": 14, - } - - async def mock_get(path: str = "", params: dict | None = None) -> Any: - return _MockResponse(200, raw_marks) - - monkeypatch.setattr(adapter._client, "get", mock_get) - - marks = await adapter.get_marks("AAPL", 0, 9999, "D") - assert len(marks) == 2 - assert marks[0]["id"] == 1 - assert marks[0]["color"] == "red" - assert marks[1]["text"] == "Mark 2" - - @pytest.mark.asyncio() - async def test_marks_array_passthrough( - self, adapter: UDFAdapter, monkeypatch: pytest.MonkeyPatch - ) -> None: - raw_marks = [ - {"id": 1, "time": 1000, "color": "blue", "text": "A"}, - {"id": 2, "time": 2000, "color": "green", "text": "B"}, - ] - - async def mock_get(path: str = "", params: dict | None = None) -> Any: - return _MockResponse(200, raw_marks) - - monkeypatch.setattr(adapter._client, "get", mock_get) - - marks = await adapter.get_marks("AAPL", 0, 9999, "D") - assert len(marks) == 2 - assert marks[0]["color"] == "blue" - - -class TestUDFAdapterQuotes: - """Test quotes endpoint parsing.""" - - @pytest.fixture() - def adapter(self) -> UDFAdapter: - return UDFAdapter("https://example.com") - - @pytest.mark.asyncio() - async def test_quotes_ok(self, adapter: UDFAdapter, monkeypatch: pytest.MonkeyPatch) -> None: - raw = { - "s": "ok", - "d": [ - { - "s": "ok", - "n": "NYSE:AA", - "v": { - "ch": 0.16, - "chp": 0.98, - "short_name": "AA", - "exchange": "NYSE", - "description": "Alcoa Inc.", - "lp": 16.57, - "volume": 4029041, - }, - }, - ], - } - - async def mock_get(path: str, params: dict | None = None) -> Any: - return _MockResponse(200, raw) - - monkeypatch.setattr(adapter._client, "get", mock_get) - - quotes = await adapter._get_quotes(["NYSE:AA"]) - assert len(quotes) == 1 - assert quotes[0].symbol == "NYSE:AA" - assert quotes[0].last_price == 16.57 - assert quotes[0].change == 0.16 - - @pytest.mark.asyncio() - async def test_quotes_error(self, adapter: UDFAdapter, monkeypatch: pytest.MonkeyPatch) -> None: - raw = {"s": "error", "errmsg": "Bad request"} - - async def mock_get(path: str, params: dict | None = None) -> Any: - return _MockResponse(200, raw) - - monkeypatch.setattr(adapter._client, "get", mock_get) - - quotes = await adapter._get_quotes(["BAD"]) - assert quotes == [] - - -class TestUDFAdapterLifecycle: - """Test adapter lifecycle methods.""" - - def test_close_idempotent(self) -> None: - adapter = UDFAdapter("https://example.com") - adapter.close() - assert adapter._closed is True - # Second close should not raise - adapter.close() - - def test_properties_before_connect(self) -> None: - adapter = UDFAdapter("https://example.com") - assert adapter.config is None - assert adapter.supports_marks is False - assert adapter.supports_time is False - assert adapter.supports_search is True - - -# --------------------------------------------------------------------------- -# Mock helpers -# --------------------------------------------------------------------------- - - -class _MockResponse: - """Minimal mock for httpx.Response.""" - - def __init__(self, status_code: int, json_data: Any) -> None: - self.status_code = status_code - self._json = json_data - self.text = str(json_data) - - def json(self) -> Any: - return self._json - - def raise_for_status(self) -> None: - if self.status_code >= 400: - raise httpx.HTTPStatusError( - f"HTTP {self.status_code}", - request=httpx.Request("GET", "https://example.com"), - response=self, # type: ignore - ) diff --git a/pywry/tests/test_watcher.py b/pywry/tests/test_watcher.py index 4771502..8351965 100644 --- a/pywry/tests/test_watcher.py +++ b/pywry/tests/test_watcher.py @@ -6,8 +6,10 @@ - FileWatcher class with mocked Observer and Timer - Debounce functionality - Global watcher functions +- _WatchHandler event forwarding +- All branches required for 100% coverage -All tests use mocks for file system and threading operations. +All tests use mocks for file system and threading operations where appropriate. """ from __future__ import annotations @@ -18,15 +20,47 @@ from pathlib import Path from unittest.mock import MagicMock, patch +import pytest + from pywry.watcher import ( FileWatcher, WatchedFile, WindowDebouncer, + _WatchHandler, get_file_watcher, stop_file_watcher, ) +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture +def css_file(tmp_path: Path) -> Path: + """Create a CSS file in tmp_path.""" + f = tmp_path / "style.css" + f.write_text("body {}") + return f + + +@pytest.fixture +def watcher() -> FileWatcher: + """Create a fresh FileWatcher with default settings.""" + return FileWatcher() + + +@pytest.fixture +def mocked_observer(): + """Patch pywry.watcher.Observer and yield the (class, instance) pair.""" + with patch("pywry.watcher.Observer") as mock_observer_class: + mock_observer = MagicMock() + mock_observer.emitters = [] + mock_observer_class.return_value = mock_observer + yield mock_observer_class, mock_observer + + # ============================================================================= # WatchedFile Tests # ============================================================================= @@ -36,7 +70,7 @@ class TestWatchedFile: """Test WatchedFile dataclass.""" def test_creation(self) -> None: - """Test creating a WatchedFile.""" + """Test creating a WatchedFile with required fields.""" callback = MagicMock() watched = WatchedFile( path=Path("/test/file.txt"), @@ -48,15 +82,6 @@ def test_creation(self) -> None: assert watched.label == "window1" assert watched.last_triggered == 0.0 - def test_last_triggered_default(self) -> None: - """Test last_triggered defaults to 0.0.""" - watched = WatchedFile( - path=Path("/test/file.txt"), - callback=MagicMock(), - label="window1", - ) - assert watched.last_triggered == 0.0 - def test_last_triggered_custom(self) -> None: """Test last_triggered with custom value.""" watched = WatchedFile( @@ -91,7 +116,7 @@ def test_pending_paths(self) -> None: assert len(debouncer.pending_paths) == 2 def test_lock_is_usable(self) -> None: - """Test that the lock is usable.""" + """Test that the lock is usable as a context manager.""" debouncer = WindowDebouncer() with debouncer.lock: debouncer.pending_paths.add(Path("/test/file.txt")) @@ -99,12 +124,12 @@ def test_lock_is_usable(self) -> None: # ============================================================================= -# FileWatcher Initialization Tests +# FileWatcher Initialization & Debounce Property Tests # ============================================================================= class TestFileWatcherInit: - """Test FileWatcher initialization.""" + """Test FileWatcher initialization and debounce_ms property.""" def test_default_debounce(self) -> None: """Test default debounce time.""" @@ -116,128 +141,180 @@ def test_custom_debounce(self) -> None: watcher = FileWatcher(debounce_ms=500) assert watcher.debounce_ms == 500 - def test_debounce_property_setter(self) -> None: - """Test debounce_ms setter.""" + def test_debounce_setter_accepts_valid(self) -> None: + """Test debounce_ms setter updates both the ms and seconds values.""" watcher = FileWatcher() - watcher.debounce_ms = 250 - assert watcher.debounce_ms == 250 + watcher.debounce_ms = 750 + assert watcher.debounce_ms == 750 + assert watcher._debounce_sec == 0.75 - def test_debounce_minimum(self) -> None: - """Test that debounce has a minimum value.""" + def test_debounce_setter_clamps_to_minimum(self) -> None: + """Test that debounce has an enforced minimum of 10ms.""" watcher = FileWatcher() watcher.debounce_ms = 5 # Below minimum - assert watcher.debounce_ms >= 10 + assert watcher.debounce_ms == 10 + watcher.debounce_ms = -10 + assert watcher.debounce_ms == 10 # ============================================================================= -# FileWatcher Watch/Unwatch Tests (with mocked filesystem) +# FileWatcher Watch/Unwatch Tests # ============================================================================= class TestFileWatcherWatch: """Test FileWatcher watch/unwatch methods.""" - def test_watch_existing_file(self, tmp_path: Path) -> None: - """Test watching an existing file.""" - test_file = tmp_path / "test.txt" - test_file.write_text("content") + def test_watch_existing_file(self, watcher: FileWatcher, css_file: Path) -> None: + """Test watching an existing file populates _watches.""" callback = MagicMock() + watcher.watch(css_file, callback, "window1") - watcher = FileWatcher() - watcher.watch(test_file, callback, "window1") - - # Verify internal state - resolved = test_file.resolve() + resolved = css_file.resolve() assert resolved in watcher._watches assert len(watcher._watches[resolved]) == 1 assert watcher._watches[resolved][0].callback is callback - def test_watch_nonexistent_file(self, tmp_path: Path) -> None: - """Test watching a non-existent file does nothing.""" - test_file = tmp_path / "nonexistent.txt" - callback = MagicMock() + def test_watch_nonexistent_file_does_nothing( + self, watcher: FileWatcher, tmp_path: Path + ) -> None: + """Test watching a non-existent file is a no-op.""" + nonexistent = tmp_path / "nonexistent.txt" + watcher.watch(nonexistent, MagicMock(), "window1") + assert nonexistent.resolve() not in watcher._watches - watcher = FileWatcher() - watcher.watch(test_file, callback, "window1") - - # Should not add to watches - resolved = test_file.resolve() - assert resolved not in watcher._watches + def test_watch_path_string(self, watcher: FileWatcher, css_file: Path) -> None: + """Test watching with string path instead of Path object.""" + watcher.watch(str(css_file), MagicMock(), "window1") + assert css_file.resolve() in watcher._watches - def test_watch_multiple_callbacks(self, tmp_path: Path) -> None: + def test_watch_multiple_callbacks(self, watcher: FileWatcher, css_file: Path) -> None: """Test watching same file with multiple callbacks.""" - test_file = tmp_path / "test.txt" - test_file.write_text("content") - callback1 = MagicMock() - callback2 = MagicMock() - + watcher.watch(css_file, MagicMock(), "window1") + watcher.watch(css_file, MagicMock(), "window2") + assert len(watcher._watches[css_file.resolve()]) == 2 + + def test_watch_while_running_schedules_directory( + self, mocked_observer, tmp_path: Path + ) -> None: + """When watch() is called while running, the directory is scheduled + with the observer. + """ + _, mock_observer = mocked_observer watcher = FileWatcher() - watcher.watch(test_file, callback1, "window1") - watcher.watch(test_file, callback2, "window2") - - resolved = test_file.resolve() - assert len(watcher._watches[resolved]) == 2 + watcher.start() - def test_unwatch_file(self, tmp_path: Path) -> None: - """Test unwatching a file.""" - test_file = tmp_path / "test.txt" - test_file.write_text("content") - callback = MagicMock() + test_file = tmp_path / "live.css" + test_file.write_text("body {}") + watcher.watch(test_file, MagicMock(), "labelA") - watcher = FileWatcher() - watcher.watch(test_file, callback, "window1") - watcher.unwatch(test_file) + scheduled_paths = [c.args[1] for c in mock_observer.schedule.call_args_list] + assert str(test_file.parent.resolve()) in scheduled_paths - resolved = test_file.resolve() - assert resolved not in watcher._watches + def test_unwatch_file(self, watcher: FileWatcher, css_file: Path) -> None: + """Test unwatching a file removes it from _watches.""" + watcher.watch(css_file, MagicMock(), "window1") + watcher.unwatch(css_file) + assert css_file.resolve() not in watcher._watches - def test_unwatch_by_label(self, tmp_path: Path) -> None: - """Test unwatching by specific label.""" - test_file = tmp_path / "test.txt" - test_file.write_text("content") - callback1 = MagicMock() - callback2 = MagicMock() + def test_unwatch_nonexistent_is_safe(self, watcher: FileWatcher) -> None: + """Test unwatching a file that was never watched is a no-op.""" + watcher.unwatch(Path("/nonexistent/file.txt")) - watcher = FileWatcher() - watcher.watch(test_file, callback1, "window1") - watcher.watch(test_file, callback2, "window2") - watcher.unwatch(test_file, label="window1") + def test_unwatch_by_label_keeps_other_labels( + self, watcher: FileWatcher, css_file: Path + ) -> None: + """Unwatching with a specific label keeps other labels intact.""" + watcher.watch(css_file, MagicMock(), "win1") + watcher.watch(css_file, MagicMock(), "win2") + watcher.unwatch(css_file, label="win1") - resolved = test_file.resolve() + resolved = css_file.resolve() + assert resolved in watcher._watches assert len(watcher._watches[resolved]) == 1 - assert watcher._watches[resolved][0].label == "window2" - - def test_unwatch_label_all(self, tmp_path: Path) -> None: - """Test unwatching all files for a label.""" + assert watcher._watches[resolved][0].label == "win2" + + def test_unwatch_by_label_removes_path_when_last( + self, watcher: FileWatcher, css_file: Path + ) -> None: + """When unwatching the only remaining label for a path, the path + entry is deleted entirely. + """ + watcher.watch(css_file, MagicMock(), "win-only") + watcher.unwatch(css_file, label="win-only") + assert css_file.resolve() not in watcher._watches + + def test_unwatch_label_all(self, watcher: FileWatcher, tmp_path: Path) -> None: + """unwatch_label removes every file watched only by that label.""" file1 = tmp_path / "test1.txt" file2 = tmp_path / "test2.txt" file1.write_text("content1") file2.write_text("content2") - callback = MagicMock() - watcher = FileWatcher() - watcher.watch(file1, callback, "window1") - watcher.watch(file2, callback, "window1") + watcher.watch(file1, MagicMock(), "window1") + watcher.watch(file2, MagicMock(), "window1") watcher.unwatch_label("window1") assert file1.resolve() not in watcher._watches assert file2.resolve() not in watcher._watches + def test_unwatch_label_keeps_path_when_other_labels_remain( + self, watcher: FileWatcher, tmp_path: Path + ) -> None: + """unwatch_label leaves path entries that other labels still watch.""" + f1 = tmp_path / "one.css" + f2 = tmp_path / "two.css" + f1.write_text("a {}") + f2.write_text("b {}") + + watcher.watch(f1, MagicMock(), "win1") + watcher.watch(f2, MagicMock(), "win1") + watcher.watch(f1, MagicMock(), "win2") # win2 also watches f1 + + watcher.unwatch_label("win1") + + # f2 fully unwatched + assert f2.resolve() not in watcher._watches + # f1 still has win2 watching it + assert f1.resolve() in watcher._watches + assert len(watcher._watches[f1.resolve()]) == 1 + assert watcher._watches[f1.resolve()][0].label == "win2" + + def test_unwatch_label_nonexistent_is_safe(self, watcher: FileWatcher) -> None: + """Test unwatching a label that was never used is a no-op.""" + watcher.unwatch_label("nonexistent_window") + + def test_unwatch_label_cancels_pending_timer( + self, css_file: Path + ) -> None: + """unwatch_label cancels a pending debouncer timer and clears state.""" + watcher = FileWatcher(debounce_ms=10000) # long so timer doesn't fire + watcher.watch(css_file, MagicMock(), "winT") + + # Trigger a change to create a debouncer with an active timer + watcher._on_file_change(css_file.resolve()) + + assert "winT" in watcher._debouncers + debouncer = watcher._debouncers["winT"] + assert debouncer.timer is not None + + watcher.unwatch_label("winT") + + # Debouncer was removed + assert "winT" not in watcher._debouncers + # ============================================================================= -# FileWatcher Start/Stop Tests (with mocked Observer) +# FileWatcher Start/Stop Tests # ============================================================================= class TestFileWatcherStartStop: """Test FileWatcher start/stop methods.""" - @patch("pywry.watcher.Observer") - def test_start_creates_observer(self, mock_observer_class: MagicMock) -> None: - """Test that start creates an Observer.""" - mock_observer = MagicMock() - mock_observer_class.return_value = mock_observer - + def test_start_creates_observer(self, mocked_observer) -> None: + """Test that start creates an Observer and marks watcher running.""" + mock_observer_class, mock_observer = mocked_observer watcher = FileWatcher() watcher.start() @@ -245,24 +322,40 @@ def test_start_creates_observer(self, mock_observer_class: MagicMock) -> None: mock_observer.start.assert_called_once() assert watcher._running is True - @patch("pywry.watcher.Observer") - def test_start_idempotent(self, mock_observer_class: MagicMock) -> None: + def test_start_idempotent(self, mocked_observer) -> None: """Test that calling start twice doesn't create two observers.""" - mock_observer = MagicMock() - mock_observer_class.return_value = mock_observer - + mock_observer_class, _ = mocked_observer watcher = FileWatcher() watcher.start() watcher.start() - mock_observer_class.assert_called_once() - @patch("pywry.watcher.Observer") - def test_stop_stops_observer(self, mock_observer_class: MagicMock) -> None: - """Test that stop stops the Observer.""" - mock_observer = MagicMock() - mock_observer_class.return_value = mock_observer + def test_start_schedules_known_directories( + self, mocked_observer, tmp_path: Path + ) -> None: + """start() must call _ensure_directory_watched for every tracked directory.""" + _, mock_observer = mocked_observer + f1 = tmp_path / "a.css" + f2 = tmp_path / "b.css" + f1.write_text("a {}") + f2.write_text("b {}") + watcher = FileWatcher() + watcher.watch(f1, MagicMock(), "w1") + watcher.watch(f2, MagicMock(), "w1") + + # Not yet running - schedule shouldn't have been called + assert mock_observer.schedule.call_count == 0 + + watcher.start() + + # After start, schedule must have been called for the parent dir + scheduled_paths = [c.args[1] for c in mock_observer.schedule.call_args_list] + assert str(tmp_path.resolve()) in scheduled_paths + + def test_stop_stops_observer(self, mocked_observer) -> None: + """Test that stop stops the Observer and marks watcher not running.""" + _, mock_observer = mocked_observer watcher = FileWatcher() watcher.start() watcher.stop() @@ -271,34 +364,99 @@ def test_stop_stops_observer(self, mock_observer_class: MagicMock) -> None: mock_observer.join.assert_called_once() assert watcher._running is False - @patch("pywry.watcher.Observer") - def test_stop_without_start(self, mock_observer_class: MagicMock) -> None: + def test_stop_without_start_is_safe(self, mocked_observer) -> None: """Test that stop without start is safe.""" + mock_observer_class, _ = mocked_observer watcher = FileWatcher() - watcher.stop() # Should not raise - + watcher.stop() mock_observer_class.assert_not_called() + def test_stop_cancels_pending_timers( + self, mocked_observer, css_file: Path + ) -> None: + """stop() cancels pending debouncer timers and clears them.""" + _, mock_observer = mocked_observer + watcher = FileWatcher(debounce_ms=10000) + watcher.watch(css_file, MagicMock(), "win") + watcher.start() + + # Trigger change so debouncer timer is set + watcher._on_file_change(css_file.resolve()) + debouncer = watcher._debouncers["win"] + assert debouncer.timer is not None + + watcher.stop() + + assert watcher._debouncers == {} + mock_observer.stop.assert_called_once() + assert watcher._running is False + + +# ============================================================================= +# _ensure_directory_watched edge cases +# ============================================================================= + + +class TestEnsureDirectoryWatched: + """Test _ensure_directory_watched branches.""" + + def test_no_observer_returns_early(self, watcher: FileWatcher, tmp_path: Path) -> None: + """Without start(), observer is None and the call is a no-op.""" + watcher._ensure_directory_watched(tmp_path) + + def test_no_handler_returns_early(self, watcher: FileWatcher, tmp_path: Path) -> None: + """With observer set but handler is None, schedule is not called.""" + watcher._observer = MagicMock() + watcher._handler = None + watcher._ensure_directory_watched(tmp_path) + watcher._observer.schedule.assert_not_called() + + def test_already_watched_directory_skips_schedule( + self, mocked_observer, tmp_path: Path + ) -> None: + """Already-watched directories don't get scheduled again.""" + _, mock_observer = mocked_observer + fake_watch = MagicMock() + fake_watch.path = str(tmp_path.resolve()) + fake_emitter = MagicMock() + fake_emitter.watch = fake_watch + mock_observer.emitters = [fake_emitter] + + watcher = FileWatcher() + watcher.start() + mock_observer.schedule.reset_mock() + + watcher._ensure_directory_watched(tmp_path.resolve()) + mock_observer.schedule.assert_not_called() + + def test_schedule_exception_is_swallowed( + self, mocked_observer, tmp_path: Path + ) -> None: + """Schedule failures are logged and swallowed (no exception raised).""" + _, mock_observer = mocked_observer + mock_observer.schedule.side_effect = OSError("nope") + + watcher = FileWatcher() + watcher.start() + # Should not raise even though schedule failed + watcher._ensure_directory_watched(tmp_path) + # ============================================================================= -# File Change Callback Tests (with mocked threading) +# File Change Callback Tests # ============================================================================= class TestFileChangeCallbacks: - """Test file change callback triggering.""" + """Test file change callback triggering and debouncing.""" - def test_on_file_change_triggers_callback(self, tmp_path: Path) -> None: - """Test that file changes trigger callbacks.""" - test_file = tmp_path / "test.txt" - test_file.write_text("content") + def test_on_file_change_triggers_callback(self, css_file: Path) -> None: + """Test that file changes trigger callbacks after the debounce period.""" callback = MagicMock() + watcher = FileWatcher(debounce_ms=10) + watcher.watch(css_file, callback, "window1") - watcher = FileWatcher(debounce_ms=10) # Short debounce for test - watcher.watch(test_file, callback, "window1") - - # Simulate file change - resolved = test_file.resolve() + resolved = css_file.resolve() watcher._on_file_change(resolved) # Wait for debounce (longer wait for CI timer scheduling variance) @@ -306,141 +464,146 @@ def test_on_file_change_triggers_callback(self, tmp_path: Path) -> None: callback.assert_called_once_with(resolved, "window1") - def test_on_file_change_unwatched_file(self, tmp_path: Path) -> None: - """Test that changes to unwatched files are ignored.""" - test_file = tmp_path / "test.txt" - unwatched_file = tmp_path / "other.txt" - test_file.write_text("content") + def test_on_file_change_unwatched_file_ignored( + self, watcher: FileWatcher, css_file: Path, tmp_path: Path + ) -> None: + """Changes to unwatched files don't fire callbacks.""" callback = MagicMock() + watcher.watch(css_file, callback, "window1") - watcher = FileWatcher() - watcher.watch(test_file, callback, "window1") - - # Simulate change to unwatched file + # Simulate change to a different, unwatched file + unwatched_file = tmp_path / "other.txt" watcher._on_file_change(unwatched_file.resolve()) - # Wait a bit (ensure debounce timer has time to fire if it would) time.sleep(0.2) - callback.assert_not_called() - def test_debounce_batches_changes(self, tmp_path: Path) -> None: - """Test that rapid changes are batched.""" - test_file = tmp_path / "test.txt" - test_file.write_text("content") - callback = MagicMock() + def test_on_file_change_empty_watches_returns_early( + self, css_file: Path + ) -> None: + """_on_file_change with an empty watch list returns without scheduling.""" + watcher = FileWatcher(debounce_ms=10) + watcher.watch(css_file, MagicMock(), "w1") + # Manually empty the list to hit the `not watches` branch + watcher._watches[css_file.resolve()] = [] - watcher = FileWatcher(debounce_ms=50) - watcher.watch(test_file, callback, "window1") + watcher._on_file_change(css_file.resolve()) + assert "w1" not in watcher._debouncers - resolved = test_file.resolve() + def test_debounce_batches_rapid_changes(self, css_file: Path) -> None: + """Rapid changes within the debounce window are coalesced into one callback.""" + callback = MagicMock() + watcher = FileWatcher(debounce_ms=50) + watcher.watch(css_file, callback, "window1") - # Simulate rapid changes + resolved = css_file.resolve() watcher._on_file_change(resolved) watcher._on_file_change(resolved) watcher._on_file_change(resolved) - # Wait for debounce (longer wait for CI timer scheduling variance) time.sleep(0.3) - - # Should only call once despite multiple changes assert callback.call_count == 1 + def test_callback_exception_is_handled(self, css_file: Path) -> None: + """Callback exceptions don't crash the watcher; the callback still fires.""" + bad_callback = MagicMock(side_effect=RuntimeError("Callback error")) + watcher = FileWatcher(debounce_ms=10) + watcher.watch(css_file, bad_callback, "window1") -# ============================================================================= -# Global Watcher Functions Tests -# ============================================================================= + watcher._on_file_change(css_file.resolve()) + time.sleep(0.2) + bad_callback.assert_called_once() -class TestGlobalWatcherFunctions: - """Test global watcher utility functions.""" - def test_get_file_watcher_creates_singleton(self) -> None: - """Test that get_file_watcher creates a singleton.""" - # Reset global state - stop_file_watcher() +# ============================================================================= +# _WatchHandler Tests +# ============================================================================= - watcher1 = get_file_watcher() - watcher2 = get_file_watcher() - assert watcher1 is watcher2 +class TestWatchHandler: + """Test _WatchHandler.on_modified event forwarding.""" - # Cleanup - stop_file_watcher() + def test_directory_event_ignored(self) -> None: + """Directory modification events are skipped (not forwarded).""" + watcher_mock = MagicMock() + handler = _WatchHandler(watcher_mock) - def test_get_file_watcher_respects_debounce(self) -> None: - """Test that first call sets debounce time.""" - stop_file_watcher() + event = MagicMock() + event.is_directory = True + handler.on_modified(event) - watcher = get_file_watcher(debounce_ms=250) - assert watcher.debounce_ms == 250 + watcher_mock._on_file_change.assert_not_called() - stop_file_watcher() + def test_file_event_forwards_resolved_path(self, tmp_path: Path) -> None: + """File events are forwarded as resolved Path objects to the watcher.""" + watcher_mock = MagicMock() + handler = _WatchHandler(watcher_mock) - def test_stop_file_watcher(self) -> None: - """Test stopping the global watcher.""" - stop_file_watcher() # Ensure clean state + f = tmp_path / "x.css" + event = MagicMock() + event.is_directory = False + event.src_path = str(f) + handler.on_modified(event) - watcher = get_file_watcher() - watcher.start() + watcher_mock._on_file_change.assert_called_once() + passed_path = watcher_mock._on_file_change.call_args[0][0] + assert isinstance(passed_path, Path) + assert passed_path == f.resolve() - stop_file_watcher() + def test_bytes_src_path_is_decoded(self, tmp_path: Path) -> None: + """Watchdog can deliver src_path as bytes; the handler decodes it.""" + watcher_mock = MagicMock() + handler = _WatchHandler(watcher_mock) - # Getting again should create new instance - watcher2 = get_file_watcher() - assert watcher2 is not watcher + f = tmp_path / "y.css" + event = MagicMock() + event.is_directory = False + event.src_path = str(f).encode() + handler.on_modified(event) - stop_file_watcher() + watcher_mock._on_file_change.assert_called_once() + passed_path = watcher_mock._on_file_change.call_args[0][0] + assert isinstance(passed_path, Path) + assert passed_path == f.resolve() # ============================================================================= -# Edge Cases +# Global Watcher Functions Tests # ============================================================================= -class TestEdgeCases: - """Test edge cases and error handling.""" - - def test_callback_exception_handled(self, tmp_path: Path) -> None: - """Test that callback exceptions don't crash the watcher.""" - test_file = tmp_path / "test.txt" - test_file.write_text("content") - - bad_callback = MagicMock(side_effect=RuntimeError("Callback error")) - - watcher = FileWatcher(debounce_ms=10) - watcher.watch(test_file, bad_callback, "window1") - - # Simulate file change - should not raise - resolved = test_file.resolve() - watcher._on_file_change(resolved) - - # Wait for debounce timer (longer wait for CI timer scheduling variance) - time.sleep(0.2) - - # Callback was called (and raised) - bad_callback.assert_called_once() - - def test_watch_path_string(self, tmp_path: Path) -> None: - """Test watching with string path instead of Path object.""" - test_file = tmp_path / "test.txt" - test_file.write_text("content") - callback = MagicMock() - - watcher = FileWatcher() - watcher.watch(str(test_file), callback, "window1") - - resolved = test_file.resolve() - assert resolved in watcher._watches +class TestGlobalWatcherFunctions: + """Test get_file_watcher / stop_file_watcher singleton functions.""" - def test_unwatch_nonexistent(self) -> None: - """Test unwatching a file that was never watched.""" - watcher = FileWatcher() - # Should not raise - watcher.unwatch(Path("/nonexistent/file.txt")) + def test_get_file_watcher_creates_singleton(self) -> None: + """Repeated get_file_watcher() calls return the same instance.""" + stop_file_watcher() + try: + w1 = get_file_watcher() + w2 = get_file_watcher() + assert w1 is w2 + finally: + stop_file_watcher() - def test_unwatch_label_nonexistent(self) -> None: - """Test unwatching a label that was never used.""" - watcher = FileWatcher() - # Should not raise - watcher.unwatch_label("nonexistent_window") + def test_get_file_watcher_respects_debounce(self) -> None: + """First call sets the debounce time.""" + stop_file_watcher() + try: + watcher = get_file_watcher(debounce_ms=250) + assert watcher.debounce_ms == 250 + finally: + stop_file_watcher() + + def test_stop_file_watcher_resets_singleton(self) -> None: + """After stop_file_watcher, get_file_watcher creates a fresh instance.""" + stop_file_watcher() + try: + w1 = get_file_watcher() + w1.start() + stop_file_watcher() + + w2 = get_file_watcher() + assert w2 is not w1 + finally: + stop_file_watcher() diff --git a/pywry/tests/test_widget_protocol.py b/pywry/tests/test_widget_protocol.py index 941631d..d2c8513 100644 --- a/pywry/tests/test_widget_protocol.py +++ b/pywry/tests/test_widget_protocol.py @@ -584,3 +584,160 @@ def test_inject_css_with_complex_styles(self, native_handle): result = native_handle.inject_css(css, "complex-styles") assert result is True mock_inject.assert_called_once() + + +# ============================================================================= +# Coverage Gaps — proxy delegation, set_min_size None branches, etc. +# ============================================================================= + + +class TestProxyAndDelegation: + def test_proxy_returns_window_proxy(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + mock_proxy.return_value = MagicMock() + proxy = native_handle.proxy + assert proxy is mock_proxy.return_value + + def test_window_property_raises(self, native_handle): + with pytest.raises(NotImplementedError): + native_handle.window + + def test_emit_fire(self, native_handle): + with patch("pywry.runtime.emit_event_fire") as mock_emit: + native_handle.emit_fire("evt:x", {"a": 1}) + mock_emit.assert_called_once_with("test-window", "evt:x", {"a": 1}) + + def test_set_focus(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.set_focus() + instance.set_focus.assert_called_once() + + def test_maximize_minimize_center(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.maximize() + native_handle.minimize() + native_handle.center() + instance.maximize.assert_called_once() + instance.minimize.assert_called_once() + instance.center.assert_called_once() + + def test_set_title(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.set_title("Hello") + instance.set_title.assert_called_once_with("Hello") + + def test_set_size(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.set_size(800, 600) + instance.set_size.assert_called_once() + + def test_set_min_size_none_clears_constraint(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.set_min_size(None, None) + instance.set_min_size.assert_called_once_with(None) + + def test_set_min_size_with_dimensions(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.set_min_size(100, 200) + instance.set_min_size.assert_called_once() + + def test_set_max_size_none_clears_constraint(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.set_max_size(None, 500) + instance.set_max_size.assert_called_once_with(None) + + def test_set_max_size_with_dimensions(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.set_max_size(800, 600) + instance.set_max_size.assert_called_once() + + def test_set_always_on_top(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.set_always_on_top(True) + instance.set_always_on_top.assert_called_once_with(True) + + def test_set_decorations(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.set_decorations(False) + instance.set_decorations.assert_called_once_with(False) + + def test_set_background_color_opaque(self, native_handle, mock_app): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.set_background_color(10, 20, 30, 255) + instance.set_background_color.assert_called_once_with((10, 20, 30, 255)) + mock_app.emit.assert_called_with( + "pywry:inject-css", + {"id": "pywry-bg-override", "css": ":root { --pywry-bg-primary: rgb(10, 20, 30) !important; }"}, + "test-window", + ) + + def test_set_background_color_translucent(self, native_handle, mock_app): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.set_background_color(10, 20, 30, 128) + css = mock_app.emit.call_args[0][1]["css"] + assert "rgba(" in css + + def test_open_devtools(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.open_devtools() + instance.open_devtools.assert_called_once() + + def test_close_devtools(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.close_devtools() + instance.close_devtools.assert_called_once() + + def test_set_zoom(self, native_handle): + with patch("pywry.window_proxy.WindowProxy") as mock_proxy: + instance = MagicMock() + mock_proxy.return_value = instance + native_handle.set_zoom(1.25) + instance.set_zoom.assert_called_once_with(1.25) + + def test_inject_css_generates_id_when_omitted(self, native_handle): + with patch("pywry.runtime.inject_css") as mock_inject: + mock_inject.return_value = True + result = native_handle.inject_css("body{}") + assert result is True + called_args = mock_inject.call_args[0] + assert called_args[0] == "test-window" + assert called_args[2].startswith("pywry-css-") + + def test_remove_css(self, native_handle): + with patch("pywry.runtime.remove_css") as mock_remove: + mock_remove.return_value = True + result = native_handle.remove_css("pywry-css-abc") + assert result is True + + def test_show_alias(self, native_handle): + with patch("pywry.runtime.show_window") as mock_show: + native_handle.show() + mock_show.assert_called_once_with("test-window") diff --git a/pywry/tests/test_window_dispatch.py b/pywry/tests/test_window_dispatch.py index 49d4421..2cd42cc 100644 --- a/pywry/tests/test_window_dispatch.py +++ b/pywry/tests/test_window_dispatch.py @@ -14,12 +14,29 @@ from __future__ import annotations +import sys + from dataclasses import dataclass from typing import Any +from unittest.mock import MagicMock, patch + + +# Coverage compatibility — ``coverage.Coverage.start()`` evicts ``pywry`` from +# ``sys.modules`` after importing it for path discovery. ``pywry.__init__`` +# sets ``sys._pytauri_standalone = True`` and registers +# ``sys.modules['__pytauri_ext_mod__']``, but the eviction can leave the flag +# set while the module is gone. Reset the flag so a fresh import re-runs +# ``_setup_pytauri_standalone`` and ``pytauri.ffi`` imports work. +if getattr(sys, "_pytauri_standalone", False) and "__pytauri_ext_mod__" not in sys.modules: + delattr(sys, "_pytauri_standalone") +from pywry._freeze import _setup_pytauri_standalone -import pytest -from pywry.window_dispatch import ( +_setup_pytauri_standalone() + +import pytest # noqa: E402 + +from pywry.window_dispatch import ( # noqa: E402 APPEARANCE_METHODS, BEHAVIOR_METHODS, COOKIE_METHODS, @@ -30,7 +47,25 @@ STATE_METHODS, VISIBILITY_METHODS, WEBVIEW_METHODS, + _call_appearance_method, + _call_behavior_method, + _call_cookie_method, + _call_cursor_method, + _call_size_position_method, + _call_state_method, + _call_visibility_method, + _call_webview_method, + _extract_position, + _extract_size, + _serialize_cookie, _serialize_monitor, + _set_background_color, + _set_badge_count, + _set_effects, + _set_icon, + _set_overlay_icon, + _set_theme, + _set_title_bar_style, call_window_method, get_window_property, ) @@ -948,3 +983,1026 @@ def test_method_categories_non_empty(self) -> None: assert len(BEHAVIOR_METHODS) > 0 assert len(WEBVIEW_METHODS) > 0 assert len(COOKIE_METHODS) > 0 + + +# ============================================================================= +# Helper extractor and serializer tests (pure functions) +# ============================================================================= + + +class TestExtractPosition: + """Tests for the position extractor.""" + + def test_none(self) -> None: + assert _extract_position(None) is None + + def test_pytauri_wrapper_zero_attr(self) -> None: + """pytauri wraps positions in `_0` tuple attribute.""" + wrapper = MagicMock() + wrapper._0 = (10, 20) + assert _extract_position(wrapper) == {"x": 10, "y": 20} + + def test_direct_tuple(self) -> None: + """Direct tuples are unpacked.""" + assert _extract_position((5, 7)) == {"x": 5, "y": 7} + + def test_object_with_xy_attrs(self) -> None: + """Objects with .x and .y attributes work too.""" + + class P: + x = 1 + y = 2 + + assert _extract_position(P()) == {"x": 1, "y": 2} + + def test_unknown_returns_none(self) -> None: + """Unknown shape returns None.""" + + class Unknown: + pass + + assert _extract_position(Unknown()) is None + + +class TestExtractSize: + """Tests for the size extractor.""" + + def test_none(self) -> None: + assert _extract_size(None) is None + + def test_pytauri_wrapper_zero_attr(self) -> None: + wrapper = MagicMock() + wrapper._0 = (640, 480) + assert _extract_size(wrapper) == {"width": 640, "height": 480} + + def test_direct_tuple(self) -> None: + assert _extract_size((800, 600)) == {"width": 800, "height": 600} + + def test_object_with_wh_attrs(self) -> None: + class S: + width = 1024 + height = 768 + + assert _extract_size(S()) == {"width": 1024, "height": 768} + + def test_unknown_returns_none(self) -> None: + class Unknown: + pass + + assert _extract_size(Unknown()) is None + + +class TestSerializeMonitorAttrStyle: + """Tests for monitor serialization with non-callable attributes.""" + + def test_with_attribute_style_monitor(self) -> None: + """Monitor can have non-callable attributes.""" + + class _Pos: + x = 0 + y = 0 + + class _Size: + width = 1920 + height = 1080 + + class StaticMonitor: + name = "Display 1" + position = _Pos() + size = _Size() + scale_factor = 2.0 + + result = _serialize_monitor(StaticMonitor()) + assert result["name"] == "Display 1" + assert result["position"] == {"x": 0, "y": 0} + assert result["size"] == {"width": 1920, "height": 1080} + assert result["scale_factor"] == 2.0 + + +# ============================================================================= +# get_window_property - additional branches with mocked window +# ============================================================================= + + +class TestGetWindowPropertyEdgeCases: + """Branches that need a Mock-style window (None returns, missing attrs).""" + + def test_url_empty_returns_string(self) -> None: + """url returns empty string when window URL is None.""" + window = MagicMock() + window.url.return_value = None + assert get_window_property(window, "url") == "" + + def test_theme_with_str(self) -> None: + """theme returns str() when theme has no .name attribute.""" + window = MagicMock() + window.theme.return_value = "Light" + # str(theme) uses fallback + assert get_window_property(window, "theme") == "Light" + + def test_inner_position_none(self) -> None: + """inner_position handles None.""" + window = MagicMock() + window.inner_position.return_value = None + assert get_window_property(window, "inner_position") is None + + def test_outer_position_none(self) -> None: + window = MagicMock() + window.outer_position.return_value = None + assert get_window_property(window, "outer_position") is None + + def test_inner_size_none(self) -> None: + window = MagicMock() + window.inner_size.return_value = None + assert get_window_property(window, "inner_size") is None + + def test_outer_size_none(self) -> None: + window = MagicMock() + window.outer_size.return_value = None + assert get_window_property(window, "outer_size") is None + + def test_current_monitor_none(self) -> None: + window = MagicMock() + window.current_monitor.return_value = None + assert get_window_property(window, "current_monitor") is None + + def test_primary_monitor_none(self) -> None: + window = MagicMock() + window.primary_monitor.return_value = None + assert get_window_property(window, "primary_monitor") is None + + def test_available_monitors_empty(self) -> None: + """Empty list returns empty list.""" + window = MagicMock() + window.available_monitors.return_value = [] + assert get_window_property(window, "available_monitors") == [] + + def test_available_monitors_none(self) -> None: + """None list returns empty list.""" + window = MagicMock() + window.available_monitors.return_value = None + assert get_window_property(window, "available_monitors") == [] + + def test_available_monitors_with_valid_entry(self) -> None: + """Single valid monitor serializes correctly.""" + window = MagicMock() + m1 = MagicMock() + m1.name = "M1" + m1.position = MagicMock(_0=(0, 0)) + m1.size = MagicMock(_0=(800, 600)) + m1.scale_factor = 1.0 + window.available_monitors.return_value = [m1] + result = get_window_property(window, "available_monitors") + assert len(result) == 1 + assert result[0]["name"] == "M1" + + def test_is_devtools_open_with_attr(self) -> None: + """is_devtools_open returns the value.""" + window = MagicMock() + window.is_devtools_open.return_value = True + assert get_window_property(window, "is_devtools_open") is True + + def test_is_devtools_open_without_attr(self) -> None: + """is_devtools_open returns False when attribute is missing.""" + + class NoDevtools: + def title(self) -> str: + return "x" + + def url(self) -> str: + return "u" + + assert get_window_property(NoDevtools(), "is_devtools_open") is False + + def test_monitor_from_point_with_args(self) -> None: + """monitor_from_point requires x/y args.""" + window = MagicMock() + m = MagicMock() + m.name = "M" + m.position = MagicMock(_0=(0, 0)) + m.size = MagicMock(_0=(100, 100)) + m.scale_factor = 1.0 + window.monitor_from_point.return_value = m + result = get_window_property(window, "monitor_from_point", {"x": 10, "y": 20}) + assert result["name"] == "M" + + def test_monitor_from_point_returns_none(self) -> None: + """monitor_from_point returning None still serializes.""" + window = MagicMock() + window.monitor_from_point.return_value = None + assert get_window_property(window, "monitor_from_point", {"x": 10, "y": 20}) is None + + def test_monitor_from_point_no_attr(self) -> None: + """monitor_from_point returns None when window doesn't support it.""" + + class NoMonitor: + pass + + assert get_window_property(NoMonitor(), "monitor_from_point", {"x": 10, "y": 20}) is None + + +# ============================================================================= +# Visibility / State / Property setter branches +# ============================================================================= + + +class TestVisibilityDispatch: + """Visibility method branches reached via _call_visibility_method.""" + + def test_close_routes(self) -> None: + window = MagicMock() + _call_visibility_method(window, "close", {}) + window.close.assert_called_once() + + def test_destroy_routes(self) -> None: + window = MagicMock() + _call_visibility_method(window, "destroy", {}) + window.destroy.assert_called_once() + + def test_set_visible_default_true(self) -> None: + """set_visible with no args defaults visible=True -> show().""" + window = MagicMock() + _call_visibility_method(window, "set_visible", {}) + window.show.assert_called_once() + + def test_start_dragging_routes(self) -> None: + window = MagicMock() + _call_visibility_method(window, "start_dragging", {}) + window.start_dragging.assert_called_once() + + def test_start_dragging_missing_attr(self) -> None: + """start_dragging is a no-op when window lacks the method.""" + + class NoDrag: + pass + + # Should not raise + _call_visibility_method(NoDrag(), "start_dragging", {}) + + +class TestStateDispatch: + """State method branches reached via _call_state_method.""" + + def test_center_routes(self) -> None: + window = MagicMock() + _call_state_method(window, "center", {}) + window.center.assert_called_once() + + def test_request_user_attention_none(self) -> None: + """attention_type=None calls with None.""" + window = MagicMock() + _call_state_method(window, "request_user_attention", {}) + window.request_user_attention.assert_called_once_with(None) + + def test_request_user_attention_with_string(self) -> None: + """String attention_type maps via UserAttentionType enum.""" + attn_mod = MagicMock() + attn_mod.CRITICAL = "critical_value" + with patch.dict("sys.modules", {"pytauri": MagicMock(UserAttentionType=attn_mod)}): + window = MagicMock() + _call_state_method( + window, + "request_user_attention", + {"attention_type": "CRITICAL"}, + ) + window.request_user_attention.assert_called_once_with("critical_value") + + def test_request_user_attention_with_object(self) -> None: + """Non-string attention_type passes through unchanged.""" + with patch.dict("sys.modules", {"pytauri": MagicMock()}): + window = MagicMock() + obj = object() + _call_state_method(window, "request_user_attention", {"attention_type": obj}) + window.request_user_attention.assert_called_once_with(obj) + + +class TestPropertySetterDispatch: + """Property setters routed through call_window_method.""" + + def test_set_enabled(self) -> None: + window = MagicMock() + call_window_method(window, "set_enabled", {"enabled": False}) + window.set_enabled.assert_called_once_with(False) + + def test_set_maximizable(self) -> None: + window = MagicMock() + call_window_method(window, "set_maximizable", {"maximizable": False}) + window.set_maximizable.assert_called_once_with(False) + + def test_set_minimizable(self) -> None: + window = MagicMock() + call_window_method(window, "set_minimizable", {"minimizable": False}) + window.set_minimizable.assert_called_once_with(False) + + def test_set_closable(self) -> None: + window = MagicMock() + call_window_method(window, "set_closable", {"closable": False}) + window.set_closable.assert_called_once_with(False) + + def test_set_always_on_bottom(self) -> None: + window = MagicMock() + call_window_method(window, "set_always_on_bottom", {"always_on_bottom": True}) + window.set_always_on_bottom.assert_called_once_with(True) + + def test_set_skip_taskbar(self) -> None: + window = MagicMock() + call_window_method(window, "set_skip_taskbar", {"skip": True}) + window.set_skip_taskbar.assert_called_once_with(True) + + +# ============================================================================= +# Size/Position helper - uses real pytauri.ffi.Position/Size wrappers +# ============================================================================= + + +class TestSizePositionHelpers: + """Helper-level tests for _call_size_position_method. + + Uses real pytauri.ffi.Position/Size types — patching ``sys.modules`` + breaks coverage's import handling on this platform, so we let the + real wrappers run and inspect the ``_0`` tuple that pytauri uses. + """ + + def test_set_size(self) -> None: + window = MagicMock() + _call_size_position_method(window, "set_size", {"width": 800, "height": 600}) + called_size = window.set_size.call_args[0][0] + assert called_size._0 == (800, 600) + + def test_set_size_default(self) -> None: + """Defaults to 800x600 when no width/height provided.""" + window = MagicMock() + _call_size_position_method(window, "set_size", {}) + called_size = window.set_size.call_args[0][0] + assert called_size._0 == (800, 600) + + def test_set_min_size_with_values(self) -> None: + window = MagicMock() + _call_size_position_method(window, "set_min_size", {"width": 100, "height": 100}) + called = window.set_min_size.call_args[0][0] + assert called._0 == (100, 100) + + def test_set_min_size_none(self) -> None: + """No width/height passes None.""" + window = MagicMock() + _call_size_position_method(window, "set_min_size", {}) + window.set_min_size.assert_called_once_with(None) + + def test_set_max_size_with_values(self) -> None: + window = MagicMock() + _call_size_position_method(window, "set_max_size", {"width": 2000, "height": 2000}) + called = window.set_max_size.call_args[0][0] + assert called._0 == (2000, 2000) + + def test_set_max_size_none(self) -> None: + window = MagicMock() + _call_size_position_method(window, "set_max_size", {}) + window.set_max_size.assert_called_once_with(None) + + def test_set_size_constraints_min_only(self) -> None: + window = MagicMock() + _call_size_position_method( + window, + "set_size_constraints", + {"min_size": {"width": 200, "height": 100}}, + ) + window.set_min_size.assert_called_once() + window.set_max_size.assert_not_called() + + def test_set_size_constraints_max_only(self) -> None: + window = MagicMock() + _call_size_position_method( + window, + "set_size_constraints", + {"max_size": {"width": 2000, "height": 1500}}, + ) + window.set_max_size.assert_called_once() + window.set_min_size.assert_not_called() + + def test_set_size_constraints_both(self) -> None: + window = MagicMock() + _call_size_position_method( + window, + "set_size_constraints", + { + "min_size": {"width": 200, "height": 100}, + "max_size": {"width": 2000, "height": 1500}, + }, + ) + window.set_min_size.assert_called_once() + window.set_max_size.assert_called_once() + + def test_set_position(self) -> None: + window = MagicMock() + _call_size_position_method(window, "set_position", {"x": 100, "y": 200}) + called = window.set_position.call_args[0][0] + assert called._0 == (100, 200) + + def test_set_position_default(self) -> None: + window = MagicMock() + _call_size_position_method(window, "set_position", {}) + called = window.set_position.call_args[0][0] + assert called._0 == (0, 0) + + +# ============================================================================= +# Appearance helpers +# ============================================================================= + + +class TestAppearanceDispatch: + """Tests for appearance method handlers.""" + + def test_set_background_color_dict(self) -> None: + """RGBA dict is unpacked correctly.""" + window = MagicMock() + _set_background_color(window, {"color": {"r": 10, "g": 20, "b": 30, "a": 200}}) + window.set_background_color.assert_called_once_with((10, 20, 30, 200)) + + def test_set_background_color_list(self) -> None: + """List/tuple is converted to tuple.""" + window = MagicMock() + _set_background_color(window, {"color": [255, 100, 50, 255]}) + window.set_background_color.assert_called_once_with((255, 100, 50, 255)) + + def test_set_background_color_unknown_type(self) -> None: + """Unknown color type defaults to (0,0,0,255).""" + window = MagicMock() + _set_background_color(window, {"color": "not-a-color"}) + window.set_background_color.assert_called_once_with((0, 0, 0, 255)) + + def test_set_background_color_no_color_key(self) -> None: + """Falls back to args dict (which is also a dict).""" + window = MagicMock() + _set_background_color(window, {"r": 1, "g": 2, "b": 3, "a": 4}) + window.set_background_color.assert_called_once_with((1, 2, 3, 4)) + + def test_set_theme_none(self) -> None: + """No theme key calls set_theme(None).""" + window = MagicMock() + _set_theme(window, {}) + window.set_theme.assert_called_once_with(None) + + def test_set_theme_string(self) -> None: + """String theme is mapped via Theme[...].""" + fake_theme = MagicMock() + fake_theme.__getitem__ = MagicMock(return_value="Light_Theme") + with patch.dict("sys.modules", {"pytauri": MagicMock(Theme=fake_theme)}): + window = MagicMock() + _set_theme(window, {"theme": "light"}) + window.set_theme.assert_called_once_with("Light_Theme") + + def test_set_theme_object(self) -> None: + """Non-string theme passes through unchanged.""" + with patch.dict("sys.modules", {"pytauri": MagicMock()}): + window = MagicMock() + obj = object() + _set_theme(window, {"theme": obj}) + window.set_theme.assert_called_once_with(obj) + + def test_set_title_bar_style_no_attr(self) -> None: + """If window lacks set_title_bar_style, no-op.""" + + class NoTBS: + pass + + # Should not raise + _set_title_bar_style(NoTBS(), {"style": "Visible"}) + + def test_set_title_bar_style_string(self) -> None: + """String style is mapped via TitleBarStyle..""" + window = MagicMock() + fake_tbs = MagicMock() + fake_tbs.Visible = "VisibleEnum" + fake_window_mod = MagicMock(TitleBarStyle=fake_tbs) + with patch.dict("sys.modules", {"pytauri.window": fake_window_mod}): + _set_title_bar_style(window, {"style": "Visible"}) + window.set_title_bar_style.assert_called_once_with("VisibleEnum") + + def test_set_title_bar_style_object(self) -> None: + """Non-string style passes through.""" + window = MagicMock() + with patch.dict("sys.modules", {"pytauri.window": MagicMock()}): + obj = object() + _set_title_bar_style(window, {"style": obj}) + window.set_title_bar_style.assert_called_once_with(obj) + + def test_set_effects_no_attr(self) -> None: + """If window lacks set_effects, no-op.""" + + class NoFx: + pass + + _set_effects(NoFx(), {"effects": {"effects": [], "state": "Active"}}) + + def test_set_effects_no_data(self) -> None: + """Empty effects_data is no-op.""" + window = MagicMock() + _set_effects(window, {}) + window.set_effects.assert_not_called() + + def test_set_effects_with_strings(self) -> None: + """String effects/state mapped via enums.""" + window = MagicMock() + + fake_effect = MagicMock() + fake_effect.MICA = "mica_val" + fake_state = MagicMock() + fake_state.ACTIVE = "active_val" + fake_effects_class = MagicMock() + fake_window_mod = MagicMock( + Effect=fake_effect, EffectState=fake_state, Effects=fake_effects_class + ) + with patch.dict("sys.modules", {"pytauri.window": fake_window_mod}): + _set_effects( + window, + { + "effects": { + "effects": ["MICA"], + "state": "ACTIVE", + "radius": 8.0, + "color": [0, 0, 0, 255], + } + }, + ) + window.set_effects.assert_called_once() + + def test_set_effects_with_objects(self) -> None: + """Non-string effects/state passes through.""" + window = MagicMock() + fake_window_mod = MagicMock() + with patch.dict("sys.modules", {"pytauri.window": fake_window_mod}): + _set_effects( + window, + { + "effects": { + "effects": [object()], + "state": object(), + } + }, + ) + window.set_effects.assert_called_once() + + def test_set_icon_with_data(self) -> None: + """Base64 icon bytes are decoded and passed.""" + import base64 + + window = MagicMock() + raw = b"\x00\x01\x02" + icon_b64 = base64.b64encode(raw).decode() + _set_icon(window, {"icon": icon_b64}) + window.set_icon.assert_called_once_with(raw) + + def test_set_icon_no_data_with_attr(self) -> None: + """No icon and window has set_icon attribute -> calls with None.""" + window = MagicMock() + _set_icon(window, {}) + window.set_icon.assert_called_once_with(None) + + def test_set_icon_no_data_no_attr(self) -> None: + """No icon and window lacks set_icon attribute -> no-op.""" + + class NoIcon: + pass + + _set_icon(NoIcon(), {}) + + def test_set_icon_with_data_no_attr(self) -> None: + """Icon data with no set_icon attribute is no-op.""" + import base64 + + class NoIcon: + pass + + b = base64.b64encode(b"x").decode() + _set_icon(NoIcon(), {"icon": b}) + + def test_set_badge_count_with_attr(self) -> None: + window = MagicMock() + _set_badge_count(window, {"count": 5}) + window.set_badge_count.assert_called_once_with(5) + + def test_set_badge_count_no_attr(self) -> None: + class NoBadge: + pass + + _set_badge_count(NoBadge(), {"count": 5}) + + def test_set_overlay_icon_with_data(self) -> None: + import base64 + + window = MagicMock() + raw = b"\x10\x20" + icon_b64 = base64.b64encode(raw).decode() + _set_overlay_icon(window, {"icon": icon_b64}) + window.set_overlay_icon.assert_called_once_with(raw) + + def test_set_overlay_icon_no_data_with_attr(self) -> None: + window = MagicMock() + _set_overlay_icon(window, {}) + window.set_overlay_icon.assert_called_once_with(None) + + def test_set_overlay_icon_no_attr_no_data(self) -> None: + class NoOverlay: + pass + + _set_overlay_icon(NoOverlay(), {}) + + def test_set_overlay_icon_with_data_no_attr(self) -> None: + """Icon data with no set_overlay_icon attribute is no-op.""" + import base64 + + class NoOverlay: + pass + + b = base64.b64encode(b"x").decode() + _set_overlay_icon(NoOverlay(), {"icon": b}) + + def test_appearance_dispatch_unknown_method(self) -> None: + """Unknown appearance method is a no-op.""" + window = MagicMock() + _call_appearance_method(window, "set_unknown", {}) + + def test_set_content_protected_via_dispatch(self) -> None: + """set_content_protected dispatches correctly.""" + window = MagicMock() + _call_appearance_method(window, "set_content_protected", {"protected": True}) + window.set_content_protected.assert_called_once_with(True) + + def test_set_shadow_via_dispatch(self) -> None: + window = MagicMock() + _call_appearance_method(window, "set_shadow", {"shadow": False}) + window.set_shadow.assert_called_once_with(False) + + +# ============================================================================= +# Cursor helpers +# ============================================================================= + + +class TestCursorDispatch: + """Tests for cursor method handlers.""" + + def test_set_cursor_icon_string(self) -> None: + """String icon mapped via CursorIcon enum.""" + window = MagicMock() + fake_cursor = MagicMock() + fake_cursor.Hand = "hand_value" + with patch.dict("sys.modules", {"pytauri": MagicMock(CursorIcon=fake_cursor)}): + _call_cursor_method(window, "set_cursor_icon", {"icon": "Hand"}) + window.set_cursor_icon.assert_called_once_with("hand_value") + + def test_set_cursor_icon_object(self) -> None: + """Non-string icon passes through.""" + window = MagicMock() + with patch.dict("sys.modules", {"pytauri": MagicMock()}): + obj = object() + _call_cursor_method(window, "set_cursor_icon", {"icon": obj}) + window.set_cursor_icon.assert_called_once_with(obj) + + def test_set_cursor_icon_no_attr(self) -> None: + """If window lacks set_cursor_icon, no-op.""" + + class NoCursor: + pass + + _call_cursor_method(NoCursor(), "set_cursor_icon", {"icon": "Hand"}) + + def test_set_cursor_position(self) -> None: + window = MagicMock() + _call_cursor_method(window, "set_cursor_position", {"x": 10, "y": 20}) + window.set_cursor_position.assert_called_once_with((10.0, 20.0)) + + def test_set_cursor_position_default(self) -> None: + window = MagicMock() + _call_cursor_method(window, "set_cursor_position", {}) + window.set_cursor_position.assert_called_once_with((0.0, 0.0)) + + def test_set_cursor_position_no_attr(self) -> None: + class NoCP: + pass + + _call_cursor_method(NoCP(), "set_cursor_position", {"x": 1, "y": 2}) + + def test_set_cursor_visible(self) -> None: + window = MagicMock() + _call_cursor_method(window, "set_cursor_visible", {"visible": False}) + window.set_cursor_visible.assert_called_once_with(False) + + def test_set_cursor_visible_no_attr(self) -> None: + class NoCV: + pass + + _call_cursor_method(NoCV(), "set_cursor_visible", {"visible": True}) + + def test_set_cursor_grab(self) -> None: + window = MagicMock() + _call_cursor_method(window, "set_cursor_grab", {"grab": True}) + window.set_cursor_grab.assert_called_once_with(True) + + def test_set_cursor_grab_no_attr(self) -> None: + class NoCG: + pass + + _call_cursor_method(NoCG(), "set_cursor_grab", {"grab": False}) + + +# ============================================================================= +# Behavior helpers +# ============================================================================= + + +class TestBehaviorDispatch: + """Tests for behavior methods.""" + + def test_set_ignore_cursor_events(self) -> None: + window = MagicMock() + _call_behavior_method(window, "set_ignore_cursor_events", {"ignore": True}) + window.set_ignore_cursor_events.assert_called_once_with(True) + + def test_set_ignore_cursor_events_no_attr(self) -> None: + class NoIgnore: + pass + + _call_behavior_method(NoIgnore(), "set_ignore_cursor_events", {"ignore": True}) + + def test_set_progress_bar_no_attr(self) -> None: + """If window lacks set_progress_bar, no-op.""" + + class NoPB: + pass + + _call_behavior_method(NoPB(), "set_progress_bar", {"state": {"status": "Normal"}}) + + def test_set_progress_bar_no_state(self) -> None: + """Empty state dict is no-op.""" + window = MagicMock() + _call_behavior_method(window, "set_progress_bar", {}) + window.set_progress_bar.assert_not_called() + + def test_set_progress_bar_with_string_status(self) -> None: + """String status mapped via ProgressBarStatus.""" + window = MagicMock() + + fake_status = MagicMock() + fake_status.Normal = "normal_val" + fake_state_class = MagicMock() + fake_window_mod = MagicMock( + ProgressBarStatus=fake_status, ProgressBarState=fake_state_class + ) + with patch.dict("sys.modules", {"pytauri.window": fake_window_mod}): + _call_behavior_method( + window, + "set_progress_bar", + {"state": {"status": "Normal", "progress": 50}}, + ) + window.set_progress_bar.assert_called_once() + + def test_set_progress_bar_with_object_status(self) -> None: + """Non-string status passes through.""" + window = MagicMock() + fake_state_class = MagicMock() + fake_window_mod = MagicMock(ProgressBarState=fake_state_class) + with patch.dict("sys.modules", {"pytauri.window": fake_window_mod}): + _call_behavior_method( + window, + "set_progress_bar", + {"state": {"status": object(), "progress": 50}}, + ) + window.set_progress_bar.assert_called_once() + + def test_set_visible_on_all_workspaces(self) -> None: + window = MagicMock() + _call_behavior_method(window, "set_visible_on_all_workspaces", {"visible": True}) + window.set_visible_on_all_workspaces.assert_called_once_with(True) + + def test_set_visible_on_all_workspaces_no_attr(self) -> None: + class NoVOAW: + pass + + _call_behavior_method(NoVOAW(), "set_visible_on_all_workspaces", {"visible": True}) + + def test_set_traffic_light_position(self) -> None: + window = MagicMock() + _call_behavior_method(window, "set_traffic_light_position", {"x": 5, "y": 10}) + window.set_traffic_light_position.assert_called_once_with((5.0, 10.0)) + + def test_set_traffic_light_position_no_attr(self) -> None: + class NoTLP: + pass + + _call_behavior_method(NoTLP(), "set_traffic_light_position", {"x": 1, "y": 2}) + + def test_set_traffic_light_position_default(self) -> None: + window = MagicMock() + _call_behavior_method(window, "set_traffic_light_position", {}) + window.set_traffic_light_position.assert_called_once_with((0.0, 0.0)) + + +# ============================================================================= +# Webview missing branches +# ============================================================================= + + +class TestWebviewDispatchEdgeCases: + """Webview branches not exercised by the StubWindow integration tests.""" + + def test_navigate_no_attr(self) -> None: + """If window has no .navigate, no-op.""" + + class NoNav: + pass + + _call_webview_method(NoNav(), "navigate", {"url": "about:blank"}) + + def test_open_devtools_no_attr(self) -> None: + class NoDT: + pass + + _call_webview_method(NoDT(), "open_devtools", {}) + + def test_close_devtools_no_attr(self) -> None: + class NoDT: + pass + + _call_webview_method(NoDT(), "close_devtools", {}) + + def test_is_devtools_open_returns_value(self) -> None: + window = MagicMock() + window.is_devtools_open.return_value = True + assert call_window_method(window, "is_devtools_open", {}) is True + + def test_is_devtools_open_no_attr(self) -> None: + """Without attr returns False.""" + + class NoDT: + pass + + assert call_window_method(NoDT(), "is_devtools_open", {}) is False + + def test_set_zoom(self) -> None: + window = MagicMock() + call_window_method(window, "set_zoom", {"scale": 1.5}) + window.set_zoom.assert_called_once_with(1.5) + + def test_set_zoom_no_attr(self) -> None: + class NoZoom: + pass + + call_window_method(NoZoom(), "set_zoom", {"scale": 2.0}) + + def test_zoom_returns_value(self) -> None: + window = MagicMock() + window.zoom.return_value = 2.0 + assert call_window_method(window, "zoom", {}) == 2.0 + + def test_zoom_no_attr(self) -> None: + class NoZoom: + pass + + assert call_window_method(NoZoom(), "zoom", {}) == 1.0 + + def test_clear_browsing_data(self) -> None: + window = MagicMock() + call_window_method(window, "clear_all_browsing_data", {}) + window.clear_all_browsing_data.assert_called_once() + + def test_clear_browsing_data_no_attr(self) -> None: + class No: + pass + + call_window_method(No(), "clear_all_browsing_data", {}) + + def test_reload_no_attr(self) -> None: + class No: + pass + + call_window_method(No(), "reload", {}) + + def test_print(self) -> None: + window = MagicMock() + call_window_method(window, "print", {}) + window.print.assert_called_once() + + def test_print_no_attr(self) -> None: + class No: + pass + + call_window_method(No(), "print", {}) + + def test_webview_dispatch_unknown(self) -> None: + """Unknown method returns None.""" + window = MagicMock() + assert _call_webview_method(window, "unknown_webview", {}) is None + + +# ============================================================================= +# Cookie helpers +# ============================================================================= + + +class TestCookieDispatch: + """Tests for cookie helper methods.""" + + def test_serialize_cookie_none(self) -> None: + assert _serialize_cookie(None) == {} + + def test_serialize_cookie_full(self) -> None: + """Full cookie attributes serialized correctly.""" + + class C: + name = "session" + value = "abc" + domain = ".example.com" + path = "/" + expires = 1234567 + http_only = True + secure = True + same_site = "Strict" + + d = _serialize_cookie(C()) + assert d["name"] == "session" + assert d["value"] == "abc" + assert d["domain"] == ".example.com" + assert d["path"] == "/" + assert d["http_only"] is True + assert d["secure"] is True + assert d["same_site"] == "Strict" + + def test_serialize_cookie_partial(self) -> None: + """Missing attrs use defaults.""" + + class C: + name = "n" + value = "v" + + d = _serialize_cookie(C()) + assert d["name"] == "n" + assert d["value"] == "v" + assert d["domain"] is None + assert d["http_only"] is False + + def test_cookie_set(self) -> None: + window = MagicMock() + _call_cookie_method(window, "set_cookie", {"cookie": {"name": "n"}}) + window.set_cookie.assert_called_once_with({"name": "n"}) + + def test_cookie_set_no_attr(self) -> None: + class No: + pass + + assert _call_cookie_method(No(), "set_cookie", {}) is None + + def test_cookie_get_with_cookies(self) -> None: + """get_cookies returns serialized cookies.""" + window = MagicMock() + + class C: + name = "n" + value = "v" + + window.get_cookies.return_value = [C()] + result = _call_cookie_method(window, "get_cookies", {}) + assert isinstance(result, list) + assert result[0]["name"] == "n" + + def test_cookie_get_empty(self) -> None: + """get_cookies returns empty list for None.""" + window = MagicMock() + window.get_cookies.return_value = None + assert _call_cookie_method(window, "get_cookies", {}) == [] + + def test_cookie_get_no_attr(self) -> None: + """get_cookies on a window without get_cookies returns [].""" + + class No: + pass + + assert _call_cookie_method(No(), "get_cookies", {}) == [] + + def test_cookie_remove(self) -> None: + window = MagicMock() + _call_cookie_method(window, "remove_cookie", {"name": "n", "url": "https://x.com"}) + window.remove_cookie.assert_called_once_with("n", "https://x.com") + + def test_cookie_remove_no_attr(self) -> None: + class No: + pass + + _call_cookie_method(No(), "remove_cookie", {"name": "n"}) + + def test_cookie_remove_all(self) -> None: + window = MagicMock() + _call_cookie_method(window, "remove_all_cookies", {}) + window.remove_all_cookies.assert_called_once() + + def test_cookie_remove_all_no_attr(self) -> None: + class No: + pass + + _call_cookie_method(No(), "remove_all_cookies", {}) + + def test_cookie_dispatch_unknown(self) -> None: + """Unknown cookie method returns None.""" + window = MagicMock() + assert _call_cookie_method(window, "unknown_cookie_method", {}) is None diff --git a/pywry/tests/test_window_proxy.py b/pywry/tests/test_window_proxy.py index 608ccc3..20ce251 100644 --- a/pywry/tests/test_window_proxy.py +++ b/pywry/tests/test_window_proxy.py @@ -1,9 +1,14 @@ -"""Integration tests for WindowProxy. +"""Tests for WindowProxy. -These tests verify that WindowProxy methods actually work by creating -real windows and testing real operations. No mocks - real execution. +This file contains BOTH: -Tests are marked slow because they spawn actual subprocess/windows. +1. Integration tests at the top — they spawn a real pytauri subprocess and + verify proxy methods drive the underlying window. They are slow and may + be skipped on headless CI. +2. Unit tests at the bottom — they patch ``pywry.window_proxy.runtime`` so + they run in milliseconds without a subprocess. These provide reliable + coverage in headless environments where the integration tests cannot + spawn windows. """ from __future__ import annotations @@ -15,6 +20,7 @@ from collections.abc import Callable from functools import wraps from typing import Any, TypeVar +from unittest.mock import MagicMock, patch import pytest @@ -23,7 +29,22 @@ from pywry.callbacks import get_registry from pywry.exceptions import IPCTimeoutError from pywry.models import ThemeMode, WindowMode -from pywry.types import PhysicalPosition, PhysicalSize +from pywry.types import ( + Cookie, + CursorIcon, + Effect, + EffectState, + LogicalPosition, + LogicalSize, + PhysicalPosition, + PhysicalSize, + ProgressBarStatus, + Theme, + TitleBarStyle, + UserAttentionType, + serialize_position, + serialize_size, +) from pywry.window_proxy import WindowProxy # Import shared test utilities from tests.conftest @@ -518,3 +539,589 @@ def test_independent_windows(self) -> None: assert "Modified 1" in proxy1.title assert "Win 2" in proxy2.title or proxy2.title != proxy1.title app.close() + + +# ============================================================================= +# Unit tests via mocked runtime — fast, headless-safe coverage +# ============================================================================= + + +@patch("pywry.window_proxy.runtime") +class TestWindowProxyMockedProperties: + """Test WindowProxy property getters via mocked runtime.""" + + def test_label(self, runtime_mock: MagicMock) -> None: + proxy = WindowProxy("test") + assert proxy.label == "test" + + def test_title(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = "T" + proxy = WindowProxy("x") + assert proxy.title == "T" + runtime_mock.window_get.assert_called_with("x", "title") + + def test_url(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = "http://x" + proxy = WindowProxy("x") + assert proxy.url == "http://x" + + def test_theme(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = "Dark" + proxy = WindowProxy("x") + assert proxy.theme == Theme.DARK + + def test_scale_factor(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = 2.0 + proxy = WindowProxy("x") + assert proxy.scale_factor == 2.0 + + def test_inner_position(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = {"x": 10, "y": 20} + proxy = WindowProxy("x") + result = proxy.inner_position + assert isinstance(result, PhysicalPosition) + assert result.x == 10 + assert result.y == 20 + + def test_outer_position(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = {"x": 5, "y": 6} + proxy = WindowProxy("x") + result = proxy.outer_position + assert isinstance(result, PhysicalPosition) + assert result.x == 5 + + def test_inner_size(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = {"width": 800, "height": 600} + proxy = WindowProxy("x") + result = proxy.inner_size + assert isinstance(result, PhysicalSize) + assert result.width == 800 + + def test_outer_size(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = {"width": 1024, "height": 768} + proxy = WindowProxy("x") + result = proxy.outer_size + assert isinstance(result, PhysicalSize) + + def test_cursor_position(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = {"x": 1, "y": 2} + proxy = WindowProxy("x") + result = proxy.cursor_position + assert isinstance(result, PhysicalPosition) + + def test_current_monitor_present(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = { + "name": "M", + "size": {"width": 1, "height": 2}, + "position": {"x": 0, "y": 0}, + "scale_factor": 1.0, + } + proxy = WindowProxy("x") + m = proxy.current_monitor + assert m is not None + assert m.name == "M" + + def test_current_monitor_none(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = None + assert WindowProxy("x").current_monitor is None + + def test_primary_monitor_present(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = { + "name": "P", + "size": {"width": 1, "height": 2}, + "position": {"x": 0, "y": 0}, + "scale_factor": 1.0, + } + assert WindowProxy("x").primary_monitor.name == "P" + + def test_primary_monitor_none(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = None + assert WindowProxy("x").primary_monitor is None + + def test_available_monitors(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = [ + { + "name": "M1", + "size": {"width": 1, "height": 1}, + "position": {"x": 0, "y": 0}, + "scale_factor": 1.0, + } + ] + result = WindowProxy("x").available_monitors + assert isinstance(result, list) + assert len(result) == 1 + + @pytest.mark.parametrize( + "prop", + [ + "is_fullscreen", + "is_minimized", + "is_maximized", + "is_focused", + "is_decorated", + "is_resizable", + "is_enabled", + "is_visible", + "is_closable", + "is_maximizable", + "is_minimizable", + "is_always_on_top", + "is_always_on_bottom", + "is_devtools_open", + ], + ) + def test_boolean_props(self, runtime_mock: MagicMock, prop: str) -> None: + runtime_mock.window_get.return_value = True + assert getattr(WindowProxy("x"), prop) is True + + +@patch("pywry.window_proxy.runtime") +class TestWindowProxyMockedActions: + """Test WindowProxy action methods route to runtime.""" + + def test_show(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").show() + runtime_mock.window_call.assert_called_once_with("x", "show") + + def test_hide(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").hide() + runtime_mock.window_call.assert_called_once_with("x", "hide") + + def test_close(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").close() + runtime_mock.window_call.assert_called_once_with("x", "close") + + def test_destroy(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").destroy() + runtime_mock.window_call.assert_called_once_with("x", "destroy") + + def test_maximize(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").maximize() + runtime_mock.window_call.assert_called_once_with("x", "maximize") + + def test_unmaximize(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").unmaximize() + runtime_mock.window_call.assert_called_once_with("x", "unmaximize") + + def test_minimize(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").minimize() + runtime_mock.window_call.assert_called_once_with("x", "minimize") + + def test_unminimize(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").unminimize() + runtime_mock.window_call.assert_called_once_with("x", "unminimize") + + def test_toggle_maximize(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").toggle_maximize() + runtime_mock.window_call.assert_called_once_with("x", "toggle_maximize") + + def test_center(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").center() + runtime_mock.window_call.assert_called_once_with("x", "center") + + def test_set_focus(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_focus() + runtime_mock.window_call.assert_called_once_with("x", "set_focus") + + def test_reload(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").reload() + runtime_mock.window_call.assert_called_once_with("x", "reload") + + def test_print_page(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").print_page() + runtime_mock.window_call.assert_called_once_with("x", "print") + + def test_open_devtools(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").open_devtools() + runtime_mock.window_call.assert_called_once_with("x", "open_devtools") + + def test_close_devtools(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").close_devtools() + runtime_mock.window_call.assert_called_once_with("x", "close_devtools") + + def test_clear_all_browsing_data(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").clear_all_browsing_data() + runtime_mock.window_call.assert_called_once_with("x", "clear_all_browsing_data") + + def test_start_dragging(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").start_dragging() + runtime_mock.window_call.assert_called_once_with("x", "start_dragging") + + def test_request_user_attention_with_type(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").request_user_attention(UserAttentionType.CRITICAL) + args = runtime_mock.window_call.call_args[0] + assert args[1] == "request_user_attention" + assert args[2] == {"attention_type": "Critical"} + + def test_request_user_attention_none(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").request_user_attention(None) + assert runtime_mock.window_call.call_args[0][2] == {"attention_type": None} + + def test_set_title(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_title("new") + runtime_mock.window_call.assert_called_once_with("x", "set_title", {"title": "new"}) + + def test_set_size(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_size(PhysicalSize(800, 600)) + args = runtime_mock.window_call.call_args[0] + assert args[1] == "set_size" + assert args[2]["width"] == 800 + + def test_set_min_size_with_size(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_min_size(LogicalSize(100, 100)) + runtime_mock.window_call.assert_called_once() + + def test_set_min_size_none(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_min_size(None) + runtime_mock.window_call.assert_called_once_with("x", "set_min_size", {}) + + def test_set_max_size_with_size(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_max_size(PhysicalSize(2000, 1500)) + runtime_mock.window_call.assert_called_once() + + def test_set_max_size_none(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_max_size(None) + runtime_mock.window_call.assert_called_once_with("x", "set_max_size", {}) + + def test_set_position(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_position(PhysicalPosition(100, 200)) + args = runtime_mock.window_call.call_args[0] + assert args[1] == "set_position" + + def test_set_fullscreen(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_fullscreen(True) + runtime_mock.window_call.assert_called_once_with( + "x", "set_fullscreen", {"fullscreen": True} + ) + + def test_set_decorations(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_decorations(False) + runtime_mock.window_call.assert_called_once_with( + "x", "set_decorations", {"decorations": False} + ) + + def test_set_always_on_top(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_always_on_top(True) + runtime_mock.window_call.assert_called_once_with( + "x", "set_always_on_top", {"always_on_top": True} + ) + + def test_set_always_on_bottom(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_always_on_bottom(True) + runtime_mock.window_call.assert_called_once_with( + "x", "set_always_on_bottom", {"always_on_bottom": True} + ) + + def test_set_resizable(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_resizable(False) + runtime_mock.window_call.assert_called_once_with( + "x", "set_resizable", {"resizable": False} + ) + + def test_set_enabled(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_enabled(False) + runtime_mock.window_call.assert_called_once_with("x", "set_enabled", {"enabled": False}) + + def test_set_closable(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_closable(False) + runtime_mock.window_call.assert_called_once_with("x", "set_closable", {"closable": False}) + + def test_set_maximizable(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_maximizable(False) + runtime_mock.window_call.assert_called_once_with( + "x", "set_maximizable", {"maximizable": False} + ) + + def test_set_minimizable(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_minimizable(False) + runtime_mock.window_call.assert_called_once_with( + "x", "set_minimizable", {"minimizable": False} + ) + + def test_set_visible_on_all_workspaces(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_visible_on_all_workspaces(True) + runtime_mock.window_call.assert_called_once_with( + "x", "set_visible_on_all_workspaces", {"visible": True} + ) + + def test_set_skip_taskbar(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_skip_taskbar(True) + runtime_mock.window_call.assert_called_once_with( + "x", "set_skip_taskbar", {"skip": True} + ) + + def test_set_cursor_icon(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_cursor_icon(CursorIcon.HAND) + runtime_mock.window_call.assert_called_once_with( + "x", "set_cursor_icon", {"icon": "Hand"} + ) + + def test_set_cursor_position(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_cursor_position(LogicalPosition(10.0, 20.0)) + runtime_mock.window_call.assert_called_once() + + def test_set_cursor_visible(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_cursor_visible(False) + runtime_mock.window_call.assert_called_once_with( + "x", "set_cursor_visible", {"visible": False} + ) + + def test_set_cursor_grab(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_cursor_grab(True) + runtime_mock.window_call.assert_called_once_with("x", "set_cursor_grab", {"grab": True}) + + def test_set_ignore_cursor_events(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_ignore_cursor_events(True) + runtime_mock.window_call.assert_called_once_with( + "x", "set_ignore_cursor_events", {"ignore": True} + ) + + def test_set_icon_with_bytes(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_icon(b"png_data") + args = runtime_mock.window_call.call_args[0] + assert args[1] == "set_icon" + assert args[2]["icon"] is not None # base64 encoded + + def test_set_icon_none(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_icon(None) + args = runtime_mock.window_call.call_args[0] + assert args[2]["icon"] is None + + def test_set_shadow(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_shadow(False) + runtime_mock.window_call.assert_called_once_with("x", "set_shadow", {"enable": False}) + + def test_set_title_bar_style(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_title_bar_style(TitleBarStyle.OVERLAY) + runtime_mock.window_call.assert_called_once_with( + "x", "set_title_bar_style", {"style": "Overlay"} + ) + + def test_set_theme_with_value(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_theme(Theme.LIGHT) + runtime_mock.window_call.assert_called_once_with("x", "set_theme", {"theme": "Light"}) + + def test_set_theme_none(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_theme(None) + runtime_mock.window_call.assert_called_once_with("x", "set_theme", {"theme": None}) + + def test_set_content_protected(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_content_protected(True) + runtime_mock.window_call.assert_called_once_with( + "x", "set_content_protected", {"protected": True} + ) + + def test_set_traffic_light_position(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_traffic_light_position(10.0, 20.0) + runtime_mock.window_call.assert_called_once_with( + "x", "set_traffic_light_position", {"x": 10.0, "y": 20.0} + ) + + def test_set_size_constraints_both(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_size_constraints( + min_size=PhysicalSize(100, 100), + max_size=PhysicalSize(2000, 1500), + ) + args = runtime_mock.window_call.call_args[0] + assert "min_size" in args[2] + assert "max_size" in args[2] + + def test_set_size_constraints_none(self, runtime_mock: MagicMock) -> None: + """Both None still calls (with empty args).""" + WindowProxy("x").set_size_constraints() + runtime_mock.window_call.assert_called_once_with("x", "set_size_constraints", {}) + + def test_monitor_from_point_present(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = { + "name": "M", + "size": {"width": 1, "height": 1}, + "position": {"x": 0, "y": 0}, + "scale_factor": 1.0, + } + result = WindowProxy("x").monitor_from_point(10.0, 20.0) + assert result is not None + assert result.name == "M" + + def test_monitor_from_point_none(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_get.return_value = None + assert WindowProxy("x").monitor_from_point(0, 0) is None + + def test_eval(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").eval("console.log()") + runtime_mock.window_call.assert_called_once_with( + "x", "eval", {"script": "console.log()"} + ) + + def test_eval_with_result(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_call.return_value = "result" + result = WindowProxy("x").eval_with_result("script") + assert result == "result" + # should pass expect_response=True + kwargs = runtime_mock.window_call.call_args.kwargs + assert kwargs.get("expect_response") is True + + def test_navigate(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").navigate("https://x.com") + runtime_mock.window_call.assert_called_once_with( + "x", "navigate", {"url": "https://x.com"} + ) + + def test_set_zoom(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_zoom(1.5) + runtime_mock.window_call.assert_called_once_with("x", "set_zoom", {"scale": 1.5}) + + def test_set_background_color(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_background_color((10, 20, 30, 255)) + args = runtime_mock.window_call.call_args[0] + assert args[1] == "set_background_color" + assert args[2]["color"] == [10, 20, 30, 255] + + def test_set_effects(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_effects({"effects": [Effect.MICA], "state": EffectState.ACTIVE}) + runtime_mock.window_call.assert_called_once() + + def test_set_progress_bar(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_progress_bar({"status": ProgressBarStatus.NORMAL, "progress": 50}) + runtime_mock.window_call.assert_called_once() + + def test_set_badge_count(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_badge_count(5) + runtime_mock.window_call.assert_called_once_with("x", "set_badge_count", {"count": 5}) + + def test_set_overlay_icon_with_bytes(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_overlay_icon(b"png") + args = runtime_mock.window_call.call_args[0] + assert args[1] == "set_overlay_icon" + assert args[2]["icon"] is not None + + def test_set_overlay_icon_none(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").set_overlay_icon(None) + args = runtime_mock.window_call.call_args[0] + assert args[2]["icon"] is None + + def test_cookies(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_call.return_value = [{"name": "n", "value": "v"}] + result = WindowProxy("x").cookies() + assert isinstance(result, list) + assert len(result) == 1 + + def test_cookies_empty(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_call.return_value = None + assert WindowProxy("x").cookies() == [] + + def test_set_cookie(self, runtime_mock: MagicMock) -> None: + cookie = Cookie(name="n", value="v") + WindowProxy("x").set_cookie(cookie) + args = runtime_mock.window_call.call_args[0] + assert args[1] == "set_cookie" + + def test_delete_cookie(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").delete_cookie("n") + runtime_mock.window_call.assert_called_once_with("x", "delete_cookie", {"name": "n"}) + + def test_remove_menu(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").remove_menu() + runtime_mock.window_call.assert_called_once_with("x", "remove_menu", {}) + + def test_hide_menu(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").hide_menu() + runtime_mock.window_call.assert_called_once_with("x", "hide_menu", {}) + + def test_show_menu(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").show_menu() + runtime_mock.window_call.assert_called_once_with("x", "show_menu", {}) + + def test_is_menu_visible(self, runtime_mock: MagicMock) -> None: + runtime_mock.window_call.return_value = True + assert WindowProxy("x").is_menu_visible() is True + + def test_mocked_repr(self, runtime_mock: MagicMock) -> None: + proxy = WindowProxy("test-label") + assert repr(proxy) == "WindowProxy('test-label')" + + +@patch("pywry.window_proxy.runtime") +class TestWindowProxyMockedMenu: + """Test WindowProxy menu interaction methods.""" + + def test_set_menu_with_proxy_object(self, runtime_mock: MagicMock) -> None: + from pywry.menu_proxy import MenuProxy + + menu = MenuProxy("m1") + wp = WindowProxy("x") + with patch("pywry.menu_proxy.runtime") as menu_runtime: + wp.set_menu(menu) + # set_as_window_menu should send via menu_proxy.runtime + menu_runtime.send_command.assert_called_once() + + def test_set_menu_with_dict(self, runtime_mock: MagicMock) -> None: + """A non-MenuProxy menu uses send_command directly.""" + + class FakeMenu: + id = "fake-menu" + + WindowProxy("x").set_menu(FakeMenu()) + runtime_mock.send_command.assert_called_once() + cmd = runtime_mock.send_command.call_args[0][0] + assert cmd["action"] == "menu_set" + assert cmd["menu_id"] == "fake-menu" + assert cmd["target"] == "window" + assert cmd["label"] == "x" + + def test_set_menu_with_str(self, runtime_mock: MagicMock) -> None: + """A bare string menu uses str() conversion.""" + WindowProxy("x").set_menu("my-menu-id") + runtime_mock.send_command.assert_called_once() + cmd = runtime_mock.send_command.call_args[0][0] + assert cmd["menu_id"] == "my-menu-id" + + def test_popup_menu_with_proxy(self, runtime_mock: MagicMock) -> None: + from pywry.menu_proxy import MenuProxy + + menu = MenuProxy("m1") + with patch("pywry.menu_proxy.runtime") as menu_runtime: + WindowProxy("x").popup_menu(menu, x=10.0, y=20.0) + menu_runtime.send_command.assert_called_once() + + def test_popup_menu_with_dict_and_position(self, runtime_mock: MagicMock) -> None: + class FakeMenu: + id = "fake" + + WindowProxy("x").popup_menu(FakeMenu(), x=5.0, y=10.0) + cmd = runtime_mock.send_command.call_args[0][0] + assert cmd["action"] == "menu_popup" + assert cmd["position"] == {"x": 5.0, "y": 10.0} + + def test_popup_menu_no_position(self, runtime_mock: MagicMock) -> None: + class FakeMenu: + id = "fake" + + WindowProxy("x").popup_menu(FakeMenu()) + cmd = runtime_mock.send_command.call_args[0][0] + assert "position" not in cmd + + def test_popup_menu_with_str(self, runtime_mock: MagicMock) -> None: + WindowProxy("x").popup_menu("menu-str") + cmd = runtime_mock.send_command.call_args[0][0] + assert cmd["menu_id"] == "menu-str" + + +class TestTypeSerialization: + """Verify the serializer helpers used by WindowProxy.""" + + def test_serialize_logical_size(self) -> None: + result = serialize_size(LogicalSize(800.0, 600.0)) + assert result["type"] == "Logical" + assert result["width"] == 800.0 + + def test_serialize_physical_size(self) -> None: + result = serialize_size(PhysicalSize(800, 600)) + assert result["type"] == "Physical" + + def test_serialize_logical_position(self) -> None: + result = serialize_position(LogicalPosition(10.0, 20.0)) + assert result["type"] == "Logical" + + def test_serialize_physical_position(self) -> None: + result = serialize_position(PhysicalPosition(10, 20)) + assert result["type"] == "Physical" From 5e18c5a1517eff5b006b75354ad1176d969244a5 Mon Sep 17 00:00:00 2001 From: deeleeramone <> Date: Mon, 15 Jun 2026 17:32:01 -0700 Subject: [PATCH 2/2] improve test coverage --- pywry/tests/test_app.py | 2166 +++++++++++ pywry/tests/test_async_helpers.py | 80 + pywry/tests/test_auth_login_page.py | 61 + pywry/tests/test_auth_pkce.py | 56 + pywry/tests/test_callbacks.py | 556 +++ pywry/tests/test_chat_html.py | 112 + pywry/tests/test_chat_models.py | 407 ++ pywry/tests/test_chat_permissions.py | 104 + pywry/tests/test_chat_providers_anthropic.py | 155 + pywry/tests/test_chat_providers_callback.py | 180 + pywry/tests/test_chat_providers_init.py | 101 + pywry/tests/test_chat_providers_magentic.py | 200 + pywry/tests/test_chat_providers_openai.py | 195 + pywry/tests/test_chat_providers_stdio.py | 1013 +++++ pywry/tests/test_cli.py | 20 +- pywry/tests/test_commands.py | 424 ++ pywry/tests/test_config.py | 2 - pywry/tests/test_deepagent_provider.py | 8 +- pywry/tests/test_grid.py | 6 +- pywry/tests/test_hot_reload.py | 20 +- pywry/tests/test_inline.py | 3661 ++++++++++++++++++ pywry/tests/test_log.py | 162 + pywry/tests/test_main.py | 2572 ++++++++++++ pywry/tests/test_marquee_e2e.py | 6 - pywry/tests/test_mcp_agentic.py | 455 +++ pywry/tests/test_mcp_builders.py | 337 ++ pywry/tests/test_mcp_docs.py | 61 + pywry/tests/test_mcp_handlers.py | 2210 +++++++++++ pywry/tests/test_mcp_install.py | 248 ++ pywry/tests/test_mcp_main.py | 331 ++ pywry/tests/test_mcp_prompts.py | 54 + pywry/tests/test_mcp_resources.py | 278 ++ pywry/tests/test_mcp_server.py | 486 +++ pywry/tests/test_mcp_skills.py | 106 + pywry/tests/test_mcp_state.py | 341 ++ pywry/tests/test_mcp_tools.py | 91 + pywry/tests/test_modal_e2e.py | 3 - pywry/tests/test_notebook.py | 1053 +++++ pywry/tests/test_pyinstaller_hook.py | 53 + pywry/tests/test_runtime.py | 1390 +++++++ pywry/tests/test_state_auth.py | 445 +++ pywry/tests/test_state_base.py | 214 + pywry/tests/test_state_callbacks.py | 275 ++ pywry/tests/test_state_factory.py | 376 ++ pywry/tests/test_state_file.py | 309 ++ pywry/tests/test_state_memory.py | 5 +- pywry/tests/test_state_mixins.py | 4 +- pywry/tests/test_state_redis.py | 4 +- pywry/tests/test_state_server.py | 469 +++ pywry/tests/test_state_sqlite.py | 44 +- pywry/tests/test_state_sync_helpers.py | 360 ++ pywry/tests/test_templates.py | 28 +- pywry/tests/test_toolbar.py | 11 +- pywry/tests/test_tvchart_config.py | 411 ++ pywry/tests/test_tvchart_datafeed.py | 173 + pywry/tests/test_tvchart_e2e.py | 3 +- pywry/tests/test_tvchart_frontend.py | 1016 +++++ pywry/tests/test_tvchart_mixin.py | 1162 ++++++ pywry/tests/test_tvchart_models.py | 586 +++ pywry/tests/test_tvchart_normalize.py | 944 +++++ pywry/tests/test_tvchart_toolbars.py | 289 ++ pywry/tests/test_tvchart_udf.py | 1174 ++++++ pywry/tests/test_types.py | 4 +- pywry/tests/test_watcher.py | 24 +- pywry/tests/test_widget.py | 891 +++++ pywry/tests/test_widget_protocol.py | 5 +- pywry/tests/test_window_manager.py | 1668 ++++++++ pywry/tests/test_window_proxy.py | 20 +- 68 files changed, 30521 insertions(+), 157 deletions(-) create mode 100644 pywry/tests/test_app.py create mode 100644 pywry/tests/test_async_helpers.py create mode 100644 pywry/tests/test_auth_login_page.py create mode 100644 pywry/tests/test_auth_pkce.py create mode 100644 pywry/tests/test_callbacks.py create mode 100644 pywry/tests/test_chat_html.py create mode 100644 pywry/tests/test_chat_models.py create mode 100644 pywry/tests/test_chat_permissions.py create mode 100644 pywry/tests/test_chat_providers_anthropic.py create mode 100644 pywry/tests/test_chat_providers_callback.py create mode 100644 pywry/tests/test_chat_providers_init.py create mode 100644 pywry/tests/test_chat_providers_magentic.py create mode 100644 pywry/tests/test_chat_providers_openai.py create mode 100644 pywry/tests/test_chat_providers_stdio.py create mode 100644 pywry/tests/test_commands.py create mode 100644 pywry/tests/test_inline.py create mode 100644 pywry/tests/test_log.py create mode 100644 pywry/tests/test_main.py create mode 100644 pywry/tests/test_mcp_agentic.py create mode 100644 pywry/tests/test_mcp_builders.py create mode 100644 pywry/tests/test_mcp_docs.py create mode 100644 pywry/tests/test_mcp_handlers.py create mode 100644 pywry/tests/test_mcp_install.py create mode 100644 pywry/tests/test_mcp_main.py create mode 100644 pywry/tests/test_mcp_prompts.py create mode 100644 pywry/tests/test_mcp_resources.py create mode 100644 pywry/tests/test_mcp_server.py create mode 100644 pywry/tests/test_mcp_skills.py create mode 100644 pywry/tests/test_mcp_state.py create mode 100644 pywry/tests/test_mcp_tools.py create mode 100644 pywry/tests/test_notebook.py create mode 100644 pywry/tests/test_pyinstaller_hook.py create mode 100644 pywry/tests/test_runtime.py create mode 100644 pywry/tests/test_state_auth.py create mode 100644 pywry/tests/test_state_base.py create mode 100644 pywry/tests/test_state_callbacks.py create mode 100644 pywry/tests/test_state_factory.py create mode 100644 pywry/tests/test_state_file.py create mode 100644 pywry/tests/test_state_server.py create mode 100644 pywry/tests/test_state_sync_helpers.py create mode 100644 pywry/tests/test_tvchart_config.py create mode 100644 pywry/tests/test_tvchart_datafeed.py create mode 100644 pywry/tests/test_tvchart_frontend.py create mode 100644 pywry/tests/test_tvchart_mixin.py create mode 100644 pywry/tests/test_tvchart_models.py create mode 100644 pywry/tests/test_tvchart_normalize.py create mode 100644 pywry/tests/test_tvchart_toolbars.py create mode 100644 pywry/tests/test_tvchart_udf.py create mode 100644 pywry/tests/test_widget.py create mode 100644 pywry/tests/test_window_manager.py diff --git a/pywry/tests/test_app.py b/pywry/tests/test_app.py new file mode 100644 index 0000000..66d7e6a --- /dev/null +++ b/pywry/tests/test_app.py @@ -0,0 +1,2166 @@ +"""Unit tests for pywry.app.PyWry targeting line coverage. + +These tests deliberately avoid spawning the pytauri subprocess. We mock +``pywry.runtime`` and the ``WindowModeBase`` instance inside the app so +the test exercises: + +* PyWry constructor across all window modes (with/without auto-fallback). +* show / show_plotly / show_dataframe / show_tvchart dispatch in NOTEBOOK, + BROWSER, and native modes. +* Auth API (login, logout, _resolve_provider, is_authenticated, + _show_login_page_and_wait, _wire_logout_handler). +* Native menu / tray helpers (create_menu, create_tray, remove_tray, + _require_native_mode). +* Event emission helpers (emit, alert, send_event, on_*, command). +* Filter/sort helpers used by server-side grid mode (_apply_grid_filter, + _apply_grid_sort, _row_matches_filter, _text_filter_match, + _number_filter_match, _get_sort_key). +* Lifecycle helpers (close, get_labels, is_open, refresh, refresh_css, + enable/disable_hot_reload, destroy, _shutdown, block). +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from pywry.callbacks import get_registry +from pywry.models import HtmlContent, ThemeMode, WindowMode +from pywry.window_manager import ( + BrowserMode, + MultiWindowMode, + NewWindowMode, + SingleWindowMode, + get_lifecycle, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def reset_state(): + """Reset shared state between tests.""" + get_registry().clear() + get_lifecycle().clear() + yield + get_registry().clear() + get_lifecycle().clear() + + +def make_app(mode=WindowMode.NEW_WINDOW, **kwargs): + """Build a PyWry instance without ever starting the subprocess. + + Mocks ``pywry.runtime`` setters so they don't touch real state and + forces ``is_headless_environment`` to False so the constructor doesn't + promote NEW_WINDOW to BROWSER. + """ + from pywry.app import PyWry + + with ( + patch("pywry.app.should_use_inline_rendering", return_value=False), + patch("pywry.app.is_headless_environment", return_value=False), + ): + return PyWry(mode=mode, **kwargs) + + +# --------------------------------------------------------------------------- +# Constructor / mode dispatch +# --------------------------------------------------------------------------- + + +class TestConstructor: + def test_default_init_creates_new_window_mode(self): + app = make_app() + assert app._mode_enum == WindowMode.NEW_WINDOW + assert isinstance(app._mode, NewWindowMode) + assert app.theme == ThemeMode.DARK + + def test_single_window_mode(self): + app = make_app(mode=WindowMode.SINGLE_WINDOW) + assert isinstance(app._mode, SingleWindowMode) + + def test_multi_window_mode(self): + app = make_app(mode=WindowMode.MULTI_WINDOW) + assert isinstance(app._mode, MultiWindowMode) + + def test_browser_mode(self): + app = make_app(mode=WindowMode.BROWSER) + assert isinstance(app._mode, BrowserMode) + + def test_notebook_mode_falls_through_to_multi_window(self): + # _create_mode returns MultiWindowMode for NOTEBOOK + app = make_app(mode=WindowMode.NOTEBOOK) + assert isinstance(app._mode, MultiWindowMode) + + def test_init_with_custom_dimensions(self): + app = make_app(title="X", width=1024, height=768) + assert app._default_config.title == "X" + assert app._default_config.width == 1024 + assert app._default_config.height == 768 + + def test_init_with_hot_reload_flag(self): + with patch("pywry.app.HotReloadManager") as mock_mgr_cls: + mock_mgr = MagicMock() + mock_mgr_cls.return_value = mock_mgr + app = make_app(hot_reload=True) + assert app._hot_reload_manager is mock_mgr + mock_mgr.start.assert_called_once() + + def test_init_with_settings_hot_reload_enabled(self): + from pywry.config import PyWrySettings + + settings = PyWrySettings() + settings.hot_reload.enabled = True + + with patch("pywry.app.HotReloadManager") as mock_mgr_cls: + mock_mgr = MagicMock() + mock_mgr_cls.return_value = mock_mgr + app = make_app(settings=settings) + assert app._hot_reload_manager is mock_mgr + + def test_auto_fallback_to_browser_on_headless(self): + from pywry.app import PyWry + + with ( + patch("pywry.app.should_use_inline_rendering", return_value=False), + patch("pywry.app.is_headless_environment", return_value=True), + ): + app = PyWry(mode=WindowMode.NEW_WINDOW) + assert app._mode_enum == WindowMode.BROWSER + assert isinstance(app._mode, BrowserMode) + + +class TestUseInline: + def test_notebook_mode_returns_true(self): + app = make_app(mode=WindowMode.NOTEBOOK) + assert app._use_inline() is True + + def test_browser_mode_returns_true(self): + app = make_app(mode=WindowMode.BROWSER) + assert app._use_inline() is True + + def test_native_mode_returns_false_when_not_inline(self): + app = make_app() + # _use_inline() is False because should_use_inline_rendering is False + # at the global module level (we don't re-patch here, but + # should_use_inline_rendering may return True or False — we only + # verify the underlying function is consulted) + with patch("pywry.app.should_use_inline_rendering", return_value=False): + assert app._use_inline() is False + + def test_inline_when_should_use_inline_true(self): + app = make_app() + with patch("pywry.app.should_use_inline_rendering", return_value=True): + assert app._use_inline() is True + + +class TestRegisterInlineWidget: + def test_with_valid_widget(self): + app = make_app(mode=WindowMode.BROWSER) + widget = MagicMock() + widget.label = "w1" + app._register_inline_widget(widget) + assert app._inline_widgets["w1"] is widget + + def test_with_missing_attributes(self): + app = make_app(mode=WindowMode.BROWSER) + widget = object() + app._register_inline_widget(widget) + assert app._inline_widgets == {} + + +# --------------------------------------------------------------------------- +# Theme property and settings property +# --------------------------------------------------------------------------- + + +class TestThemeAndSettings: + def test_theme_setter_updates(self): + app = make_app() + app.theme = ThemeMode.LIGHT + assert app.theme == ThemeMode.LIGHT + + def test_settings_returns_pywrysettings(self): + from pywry.config import PyWrySettings + + app = make_app() + assert isinstance(app.settings, PyWrySettings) + + def test_default_config_returns_window_config(self): + from pywry.models import WindowConfig + + app = make_app() + assert isinstance(app.default_config, WindowConfig) + + def test_set_initialization_script(self): + app = make_app() + app.set_initialization_script("console.log('init');") + assert app._default_config.initialization_script == "console.log('init');" + + +# --------------------------------------------------------------------------- +# Auth API +# --------------------------------------------------------------------------- + + +class TestAuthApi: + def test_is_authenticated_false_by_default(self): + app = make_app() + assert app.is_authenticated is False + + def test_is_authenticated_true_when_result_success(self): + app = make_app() + result = MagicMock() + result.success = True + app._auth_result = result + assert app.is_authenticated is True + + def test_is_authenticated_false_when_result_unsuccessful(self): + app = make_app() + result = MagicMock() + result.success = False + app._auth_result = result + assert app.is_authenticated is False + + def test_resolve_provider_with_oauth_provider_passes_through(self): + from pywry.auth.providers import OAuthProvider + + app = make_app() + provider = MagicMock(spec=OAuthProvider) + result = app._resolve_provider(provider) + assert result is provider + + def test_resolve_provider_none_raises_when_no_settings(self): + from pywry.exceptions import AuthenticationError + + app = make_app() + # No oauth2 settings — _resolve_provider should raise + with patch("pywry.config.OAuth2Settings") as mock_settings_cls: + mock_settings_cls.side_effect = Exception("no env") + with pytest.raises(AuthenticationError): + app._resolve_provider(None) + + def test_resolve_provider_creates_from_settings(self): + from pywry.config import OAuth2Settings + + app = make_app() + app._settings.oauth2 = OAuth2Settings( + provider="google", client_id="cid", client_secret="cs", scopes="openid" + ) + with patch("pywry.auth.providers.create_provider_from_settings") as mock_create: + mock_provider = MagicMock() + mock_create.return_value = mock_provider + result = app._resolve_provider(None) + assert result is mock_provider + + def test_resolve_provider_assigns_oauth_settings_when_missing(self): + """When self._settings.oauth2 is None but env-derived OAuth2Settings has + a client_id, _resolve_provider populates self._settings.oauth2 (line 319).""" + from pywry.config import OAuth2Settings + + app = make_app() + app._settings.oauth2 = None + env_settings = OAuth2Settings( + provider="google", client_id="abc", client_secret="x", scopes="openid" + ) + with ( + patch("pywry.config.OAuth2Settings", return_value=env_settings), + patch( + "pywry.auth.providers.create_provider_from_settings", + return_value="prov", + ), + ): + result = app._resolve_provider(None) + assert result == "prov" + assert app._settings.oauth2 is env_settings + + def test_resolve_provider_non_oauth_uses_create_from_settings(self): + app = make_app() + # Not an OAuthProvider — falls through to create_provider_from_settings + with patch("pywry.auth.providers.create_provider_from_settings") as mock_create: + mock_create.return_value = "stub" + result = app._resolve_provider({"some": "config"}) + assert result == "stub" + + def test_logout_when_not_logged_in(self): + app = make_app() + # Should not raise - alert is mocked via emit + with patch.object(app, "alert"): + app.logout() + assert app._session_manager is None + assert app._auth_result is None + + def test_logout_with_session_manager(self): + app = make_app() + sm = MagicMock() + app._session_manager = sm + app._auth_result = MagicMock() + with ( + patch("pywry.state.sync_helpers.run_async") as mock_run, + patch.object(app, "alert"), + ): + app.logout() + mock_run.assert_called_once() + assert app._session_manager is None + assert app._auth_result is None + + def test_logout_no_alert_when_disabled(self): + app = make_app() + with patch.object(app, "alert") as mock_alert: + app.logout(auto_alert=False) + mock_alert.assert_not_called() + + def test_login_handles_authentication_failure(self): + app = make_app() + with ( + patch.object(app, "_resolve_provider", return_value=MagicMock()), + patch("pywry.auth.flow.AuthFlowManager") as mock_fm_cls, + patch("pywry.auth.session.SessionManager"), + patch("pywry.auth.token_store.get_token_store"), + patch.object(app, "alert"), + ): + mock_fm = MagicMock() + mock_fm.authenticate.side_effect = RuntimeError("auth failed") + mock_fm_cls.return_value = mock_fm + with pytest.raises(RuntimeError): + app.login() + + def test_login_unsuccessful_result(self): + app = make_app() + with ( + patch.object(app, "_resolve_provider", return_value=MagicMock()), + patch("pywry.auth.flow.AuthFlowManager") as mock_fm_cls, + patch("pywry.auth.session.SessionManager"), + patch("pywry.auth.token_store.get_token_store"), + patch.object(app, "alert"), + ): + mock_fm = MagicMock() + result = MagicMock() + result.success = False + result.error = "err" + mock_fm.authenticate.return_value = result + mock_fm_cls.return_value = mock_fm + r = app.login() + assert r is result + + def test_login_successful_with_user_info(self): + app = make_app() + + on_login_calls = [] + + def on_login(res): + on_login_calls.append(res) + + with ( + patch.object(app, "_resolve_provider", return_value=MagicMock()), + patch("pywry.auth.flow.AuthFlowManager") as mock_fm_cls, + patch("pywry.auth.session.SessionManager"), + patch("pywry.auth.token_store.get_token_store"), + patch.object(app, "alert"), + ): + mock_fm = MagicMock() + result = MagicMock() + result.success = True + result.user_info = {"name": "Alice", "email": "a@x.com"} + mock_fm.authenticate.return_value = result + mock_fm_cls.return_value = mock_fm + app.login(on_login=on_login) + assert on_login_calls == [result] + assert app._auth_result is result + + def test_login_successful_without_user_info(self): + app = make_app() + with ( + patch.object(app, "_resolve_provider", return_value=MagicMock()), + patch("pywry.auth.flow.AuthFlowManager") as mock_fm_cls, + patch("pywry.auth.session.SessionManager"), + patch("pywry.auth.token_store.get_token_store"), + patch.object(app, "alert"), + ): + mock_fm = MagicMock() + result = MagicMock() + result.success = True + result.user_info = None + mock_fm.authenticate.return_value = result + mock_fm_cls.return_value = mock_fm + app.login() + assert app._auth_result is result + + def test_login_with_show_page(self): + app = make_app() + with ( + patch.object(app, "_resolve_provider", return_value=MagicMock()), + patch.object(app, "_show_login_page_and_wait") as mock_show_page, + patch("pywry.auth.flow.AuthFlowManager") as mock_fm_cls, + patch("pywry.auth.session.SessionManager"), + patch("pywry.auth.token_store.get_token_store"), + patch.object(app, "alert"), + ): + mock_fm = MagicMock() + result = MagicMock() + result.success = True + result.user_info = {"name": "Bob"} + mock_fm.authenticate.return_value = result + mock_fm_cls.return_value = mock_fm + app.login(show_page=True, page_title="Hi") + mock_show_page.assert_called_once() + + def test_show_login_page_and_wait_uses_threading_event(self): + app = make_app() + provider = MagicMock() + # We replace show with a mock that simulates the click event firing + captured_callbacks = {} + + def fake_show(page, **kwargs): + captured_callbacks.update(kwargs.get("callbacks") or {}) + # Auto-trigger the click handler so wait() returns immediately + from pywry.auth.login_page import LOGIN_CLICK_EVENT + + cb = captured_callbacks.get(LOGIN_CLICK_EVENT) + if cb: + cb({}) + + with patch.object(app, "show", side_effect=fake_show): + app._show_login_page_and_wait(provider, "Sign in") + + def test_wire_logout_handler_registers_for_each_label(self): + app = make_app() + # Stub mode.get_labels + app._mode = MagicMock() + app._mode.get_labels.return_value = ["w1", "w2"] + called = [] + + def on_logout(): + called.append("logged-out") + + registry = get_registry() + original_register = registry.register + + registered_pairs = [] + + def fake_register(label, event_type, fn, **kwargs): + registered_pairs.append((label, event_type)) + return original_register(label, event_type, fn, **kwargs) + + with patch.object(registry, "register", side_effect=fake_register): + app._wire_logout_handler( + provider=MagicMock(), + on_login=None, + on_logout=on_logout, + show_page=False, + page_title="X", + auto_alert=False, + kwargs={}, + ) + assert ("w1", "auth:do-logout") in registered_pairs + assert ("w2", "auth:do-logout") in registered_pairs + + def test_wire_logout_handler_default_main_label_when_no_labels(self): + app = make_app() + app._mode = MagicMock() + app._mode.get_labels.return_value = [] + + registered_pairs = [] + registry = get_registry() + original_register = registry.register + + def fake_register(label, event_type, fn, **kwargs): + registered_pairs.append((label, event_type)) + return original_register(label, event_type, fn, **kwargs) + + with patch.object(registry, "register", side_effect=fake_register): + app._wire_logout_handler( + provider=MagicMock(), + on_login=None, + on_logout=None, + show_page=False, + page_title="X", + auto_alert=False, + kwargs={}, + ) + assert ("main", "auth:do-logout") in registered_pairs + + def test_wire_logout_handler_logout_dispatches(self): + app = make_app() + app._mode = MagicMock() + app._mode.get_labels.return_value = ["w1"] + called = [] + + def on_logout(): + called.append("ok") + + with patch.object(app, "logout") as mock_logout: + app._wire_logout_handler( + provider=MagicMock(), + on_login=None, + on_logout=on_logout, + show_page=False, + page_title="X", + auto_alert=True, + kwargs={}, + ) + # Now dispatch to trigger the inner _handle_logout + registry = get_registry() + registry.dispatch("w1", "auth:do-logout", {}) + assert called == ["ok"] + mock_logout.assert_called_once_with(auto_alert=True) + + def test_wire_logout_handler_re_enters_login_when_show_page(self): + app = make_app() + app._mode = MagicMock() + app._mode.get_labels.return_value = ["w1"] + + with ( + patch.object(app, "logout"), + patch.object(app, "login") as mock_login, + ): + app._wire_logout_handler( + provider=MagicMock(), + on_login=None, + on_logout=None, + show_page=True, + page_title="X", + auto_alert=True, + kwargs={}, + ) + registry = get_registry() + registry.dispatch("w1", "auth:do-logout", {}) + mock_login.assert_called_once() + + +# --------------------------------------------------------------------------- +# Menu / Tray / require_native_mode +# --------------------------------------------------------------------------- + + +class TestNativeRequireMode: + def test_require_native_mode_passes_for_new_window(self): + app = make_app(mode=WindowMode.NEW_WINDOW) + app._require_native_mode("foo()") # No raise + + def test_require_native_mode_passes_for_single_window(self): + app = make_app(mode=WindowMode.SINGLE_WINDOW) + app._require_native_mode("foo()") + + def test_require_native_mode_passes_for_multi_window(self): + app = make_app(mode=WindowMode.MULTI_WINDOW) + app._require_native_mode("foo()") + + def test_require_native_mode_raises_for_browser(self): + app = make_app(mode=WindowMode.BROWSER) + with pytest.raises(RuntimeError, match="requires a native window mode"): + app._require_native_mode("foo()") + + def test_require_native_mode_raises_for_notebook(self): + app = make_app(mode=WindowMode.NOTEBOOK) + with pytest.raises(RuntimeError): + app._require_native_mode("foo()") + + +class TestCreateMenu: + def test_create_menu_native_mode(self): + app = make_app() + with patch("pywry.menu_proxy.MenuProxy.create") as mock_create: + mock_proxy = MagicMock() + mock_create.return_value = mock_proxy + result = app.create_menu("menu1", items=[]) + assert result is mock_proxy + + def test_create_menu_raises_in_browser_mode(self): + app = make_app(mode=WindowMode.BROWSER) + with pytest.raises(RuntimeError): + app.create_menu("menu1") + + +class TestCreateTray: + def test_create_tray_native_mode(self): + app = make_app() + with patch("pywry.tray_proxy.TrayProxy.create") as mock_create: + mock_tray = MagicMock() + mock_create.return_value = mock_tray + result = app.create_tray("t1", tooltip="foo", title="bar", icon=b"X") + assert result is mock_tray + assert app._trays["t1"] is mock_tray + + def test_create_tray_raises_in_browser_mode(self): + app = make_app(mode=WindowMode.BROWSER) + with pytest.raises(RuntimeError): + app.create_tray("t1") + + def test_remove_tray_from_app_registry(self): + app = make_app() + tray = MagicMock() + app._trays["t1"] = tray + app.remove_tray("t1") + tray.remove.assert_called_once() + assert "t1" not in app._trays + + def test_remove_tray_falls_back_to_class_registry(self): + from pywry.tray_proxy import TrayProxy + + app = make_app() + # Not in app._trays, but in class-level registry + class_tray = MagicMock() + with patch.object(TrayProxy, "_all_proxies", {"t1": class_tray}): + app.remove_tray("t1") + class_tray.remove.assert_called_once() + + def test_remove_tray_when_not_found(self): + from pywry.tray_proxy import TrayProxy + + app = make_app() + # Not anywhere — should be a no-op + with patch.object(TrayProxy, "_all_proxies", {}): + app.remove_tray("nonexistent") # No raise + + +# --------------------------------------------------------------------------- +# show() in native + inline modes +# --------------------------------------------------------------------------- + + +class TestShowInline: + def test_show_in_browser_mode_calls_inline_show(self): + app = make_app(mode=WindowMode.BROWSER) + fake_widget = MagicMock() + fake_widget.label = "w-x" + with patch("pywry.inline.show", return_value=fake_widget) as mock_show: + result = app.show("

hi

", title="T", height=400) + assert result is fake_widget + mock_show.assert_called_once() + # open_browser=True for browser mode + kwargs = mock_show.call_args.kwargs + assert kwargs["open_browser"] is True + + def test_show_in_browser_mode_with_html_content(self): + from pywry.models import HtmlContent + + app = make_app(mode=WindowMode.BROWSER) + fake_widget = MagicMock() + fake_widget.label = "w-x" + content = HtmlContent(html="

x

", inline_css="body{}") + with patch("pywry.inline.show", return_value=fake_widget) as mock_show: + app.show(content) + kwargs = mock_show.call_args.kwargs + # inline_css should be prepended to html + assert "