diff --git a/agentrun/tool/api/mcp.py b/agentrun/tool/api/mcp.py index 9038d17..505eed0 100644 --- a/agentrun/tool/api/mcp.py +++ b/agentrun/tool/api/mcp.py @@ -10,9 +10,12 @@ import httpx -from agentrun.tool.model import ToolInfo, ToolSchema +from agentrun.tool.model import ToolInfo from agentrun.utils.config import Config from agentrun.utils.log import logger +from agentrun.utils.ram_signature import get_agentrun_signed_headers + +_MCP_METADATA_TIMEOUT_SECONDS = 30.0 def _get_or_create_event_loop() -> asyncio.AbstractEventLoop: @@ -30,9 +33,6 @@ def _get_or_create_event_loop() -> asyncio.AbstractEventLoop: return loop -from agentrun.utils.ram_signature import get_agentrun_signed_headers - - class _AgentrunRamAuth(httpx.Auth): """httpx Auth handler:为每次请求动态生成 RAM 签名。 @@ -144,6 +144,32 @@ def is_streamable(self) -> bool: """是否使用 Streamable HTTP 传输 / Whether to use Streamable HTTP transport""" return self.session_affinity == "MCP_STREAMABLE" + def _metadata_timeout_seconds(self) -> float: + timeout = self.config.get_timeout() + if timeout and timeout > 0: + return min(float(timeout), _MCP_METADATA_TIMEOUT_SECONDS) + return _MCP_METADATA_TIMEOUT_SECONDS + + def _invoke_timeout_seconds(self) -> float: + timeout = self.config.get_timeout() + if timeout and timeout > 0: + return float(timeout) + return 600.0 + + async def _wait_for_mcp_request( + self, + awaitable: Any, + operation: str, + timeout: float, + ) -> Any: + try: + return await asyncio.wait_for(awaitable, timeout=timeout) + except asyncio.TimeoutError as exc: + raise TimeoutError( + f"MCP {operation} timed out after {timeout:g}s for endpoint" + f" {self.endpoint}" + ) from exc + def _build_ram_auth(self, url: str) -> tuple: """当目标是 agentrun-data 域名时,改写 URL 并返回 httpx Auth handler。 @@ -199,8 +225,17 @@ async def list_tools_async(self) -> List[ToolInfo]: async with ClientSession( read_stream, write_stream ) as session: - await session.initialize() - result = await session.list_tools() + metadata_timeout = self._metadata_timeout_seconds() + await self._wait_for_mcp_request( + session.initialize(), + "initialize", + metadata_timeout, + ) + result = await self._wait_for_mcp_request( + session.list_tools(), + "list_tools", + metadata_timeout, + ) return [ ToolInfo.from_mcp_tool(tool) for tool in result.tools @@ -215,8 +250,17 @@ async def list_tools_async(self) -> List[ToolInfo]: async with ClientSession( read_stream, write_stream ) as session: - await session.initialize() - result = await session.list_tools() + metadata_timeout = self._metadata_timeout_seconds() + await self._wait_for_mcp_request( + session.initialize(), + "initialize", + metadata_timeout, + ) + result = await self._wait_for_mcp_request( + session.list_tools(), + "list_tools", + metadata_timeout, + ) return [ ToolInfo.from_mcp_tool(tool) for tool in result.tools @@ -266,9 +310,15 @@ async def call_tool_async( async with ClientSession( read_stream, write_stream ) as session: - await session.initialize() - result = await session.call_tool( - name, arguments=arguments or {} + await self._wait_for_mcp_request( + session.initialize(), + "initialize", + self._metadata_timeout_seconds(), + ) + result = await self._wait_for_mcp_request( + session.call_tool(name, arguments=arguments or {}), + f"call_tool {name}", + self._invoke_timeout_seconds(), ) return result else: @@ -281,9 +331,15 @@ async def call_tool_async( async with ClientSession( read_stream, write_stream ) as session: - await session.initialize() - result = await session.call_tool( - name, arguments=arguments or {} + await self._wait_for_mcp_request( + session.initialize(), + "initialize", + self._metadata_timeout_seconds(), + ) + result = await self._wait_for_mcp_request( + session.call_tool(name, arguments=arguments or {}), + f"call_tool {name}", + self._invoke_timeout_seconds(), ) return result except ImportError: diff --git a/tests/unittests/tool/test_mcp.py b/tests/unittests/tool/test_mcp.py index 907c007..e3aa853 100644 --- a/tests/unittests/tool/test_mcp.py +++ b/tests/unittests/tool/test_mcp.py @@ -4,13 +4,15 @@ Tests MCP protocol interaction functionality of ToolMCPSession. """ +import asyncio import sys -from unittest.mock import AsyncMock, MagicMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from agentrun.tool.api.mcp import ToolMCPSession from agentrun.tool.model import ToolInfo +from agentrun.utils.config import Config class TestToolMCPSessionInit: @@ -186,6 +188,36 @@ def mock_import(name, *args, **kwargs): sys.modules.update(saved_modules) assert tools == [] + @pytest.mark.asyncio + async def test_list_tools_async_initialize_timeout(self): + """测试 initialize 无响应时不会无限等待""" + + async def never_return(): + await asyncio.Event().wait() + + mock_session = AsyncMock() + mock_session.initialize = never_return + mock_session.list_tools = AsyncMock() + + mock_modules = _setup_mock_mcp_modules(mock_session) + + with patch.dict(sys.modules, mock_modules): + with patch( + "agentrun.tool.api.mcp._MCP_METADATA_TIMEOUT_SECONDS", + 0.01, + ): + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_STREAMABLE", + ) + + with pytest.raises( + TimeoutError, match="MCP initialize timed out" + ): + await session.list_tools_async() + + mock_session.list_tools.assert_not_called() + class TestToolMCPSessionListTools: """测试 list_tools 同步方法""" @@ -258,6 +290,31 @@ async def test_call_tool_async_sse_mode(self): assert result == mock_call_result + @pytest.mark.asyncio + async def test_call_tool_async_timeout(self): + """测试工具调用无响应时会按 Config.timeout 退出""" + + async def never_return(*args, **kwargs): + await asyncio.Event().wait() + + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.call_tool = never_return + + mock_modules = _setup_mock_mcp_modules(mock_session) + + with patch.dict(sys.modules, mock_modules): + session = ToolMCPSession( + endpoint="http://example.com/mcp", + session_affinity="MCP_STREAMABLE", + config=Config(timeout=0.01), + ) + + with pytest.raises( + TimeoutError, match="MCP call_tool test_tool timed out" + ): + await session.call_tool_async("test_tool", {"key": "val"}) + @pytest.mark.asyncio async def test_call_tool_async_import_error(self): """测试 mcp 未安装时抛出 ImportError"""