Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 70 additions & 14 deletions agentrun/tool/api/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@

import httpx

from agentrun.tool.model import ToolInfo, ToolSchema
from agentrun.tool.model import ToolInfo
from agentrun.utils.config import Config
from agentrun.utils.log import logger
from agentrun.utils.ram_signature import get_agentrun_signed_headers

_MCP_METADATA_TIMEOUT_SECONDS = 30.0


def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
Expand All @@ -30,9 +33,6 @@ def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
return loop


from agentrun.utils.ram_signature import get_agentrun_signed_headers


class _AgentrunRamAuth(httpx.Auth):
"""httpx Auth handler:为每次请求动态生成 RAM 签名。

Expand Down Expand Up @@ -144,6 +144,32 @@ def is_streamable(self) -> bool:
"""是否使用 Streamable HTTP 传输 / Whether to use Streamable HTTP transport"""
return self.session_affinity == "MCP_STREAMABLE"

def _metadata_timeout_seconds(self) -> float:
timeout = self.config.get_timeout()
if timeout and timeout > 0:
return min(float(timeout), _MCP_METADATA_TIMEOUT_SECONDS)
return _MCP_METADATA_TIMEOUT_SECONDS

def _invoke_timeout_seconds(self) -> float:
timeout = self.config.get_timeout()
if timeout and timeout > 0:
return float(timeout)
return 600.0

async def _wait_for_mcp_request(
self,
awaitable: Any,
operation: str,
timeout: float,
) -> Any:
try:
return await asyncio.wait_for(awaitable, timeout=timeout)
except asyncio.TimeoutError as exc:
raise TimeoutError(
f"MCP {operation} timed out after {timeout:g}s for endpoint"
f" {self.endpoint}"
) from exc

def _build_ram_auth(self, url: str) -> tuple:
"""当目标是 agentrun-data 域名时,改写 URL 并返回 httpx Auth handler。

Expand Down Expand Up @@ -199,8 +225,17 @@ async def list_tools_async(self) -> List[ToolInfo]:
async with ClientSession(
read_stream, write_stream
) as session:
await session.initialize()
result = await session.list_tools()
metadata_timeout = self._metadata_timeout_seconds()
await self._wait_for_mcp_request(
session.initialize(),
"initialize",
metadata_timeout,
)
result = await self._wait_for_mcp_request(
session.list_tools(),
"list_tools",
metadata_timeout,
)
return [
ToolInfo.from_mcp_tool(tool)
for tool in result.tools
Expand All @@ -215,8 +250,17 @@ async def list_tools_async(self) -> List[ToolInfo]:
async with ClientSession(
read_stream, write_stream
) as session:
await session.initialize()
result = await session.list_tools()
metadata_timeout = self._metadata_timeout_seconds()
await self._wait_for_mcp_request(
session.initialize(),
"initialize",
metadata_timeout,
)
result = await self._wait_for_mcp_request(
session.list_tools(),
"list_tools",
metadata_timeout,
)
return [
ToolInfo.from_mcp_tool(tool)
for tool in result.tools
Expand Down Expand Up @@ -266,9 +310,15 @@ async def call_tool_async(
async with ClientSession(
read_stream, write_stream
) as session:
await session.initialize()
result = await session.call_tool(
name, arguments=arguments or {}
await self._wait_for_mcp_request(
session.initialize(),
"initialize",
self._metadata_timeout_seconds(),
)
result = await self._wait_for_mcp_request(
session.call_tool(name, arguments=arguments or {}),
f"call_tool {name}",
self._invoke_timeout_seconds(),
)
return result
else:
Expand All @@ -281,9 +331,15 @@ async def call_tool_async(
async with ClientSession(
read_stream, write_stream
) as session:
await session.initialize()
result = await session.call_tool(
name, arguments=arguments or {}
await self._wait_for_mcp_request(
session.initialize(),
"initialize",
self._metadata_timeout_seconds(),
)
result = await self._wait_for_mcp_request(
session.call_tool(name, arguments=arguments or {}),
f"call_tool {name}",
self._invoke_timeout_seconds(),
)
return result
except ImportError:
Expand Down
59 changes: 58 additions & 1 deletion tests/unittests/tool/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
Tests MCP protocol interaction functionality of ToolMCPSession.
"""

import asyncio
import sys
from unittest.mock import AsyncMock, MagicMock, Mock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentrun.tool.api.mcp import ToolMCPSession
from agentrun.tool.model import ToolInfo
from agentrun.utils.config import Config


class TestToolMCPSessionInit:
Expand Down Expand Up @@ -186,6 +188,36 @@ def mock_import(name, *args, **kwargs):
sys.modules.update(saved_modules)
assert tools == []

@pytest.mark.asyncio
async def test_list_tools_async_initialize_timeout(self):
"""测试 initialize 无响应时不会无限等待"""

async def never_return():
await asyncio.Event().wait()

mock_session = AsyncMock()
mock_session.initialize = never_return
mock_session.list_tools = AsyncMock()

mock_modules = _setup_mock_mcp_modules(mock_session)

with patch.dict(sys.modules, mock_modules):
with patch(
"agentrun.tool.api.mcp._MCP_METADATA_TIMEOUT_SECONDS",
0.01,
):
session = ToolMCPSession(
endpoint="http://example.com/mcp",
session_affinity="MCP_STREAMABLE",
)

with pytest.raises(
TimeoutError, match="MCP initialize timed out"
):
await session.list_tools_async()

mock_session.list_tools.assert_not_called()


class TestToolMCPSessionListTools:
"""测试 list_tools 同步方法"""
Expand Down Expand Up @@ -258,6 +290,31 @@ async def test_call_tool_async_sse_mode(self):

assert result == mock_call_result

@pytest.mark.asyncio
async def test_call_tool_async_timeout(self):
"""测试工具调用无响应时会按 Config.timeout 退出"""

async def never_return(*args, **kwargs):
await asyncio.Event().wait()

mock_session = AsyncMock()
mock_session.initialize = AsyncMock()
mock_session.call_tool = never_return

mock_modules = _setup_mock_mcp_modules(mock_session)

with patch.dict(sys.modules, mock_modules):
session = ToolMCPSession(
endpoint="http://example.com/mcp",
session_affinity="MCP_STREAMABLE",
config=Config(timeout=0.01),
)

with pytest.raises(
TimeoutError, match="MCP call_tool test_tool timed out"
):
await session.call_tool_async("test_tool", {"key": "val"})

@pytest.mark.asyncio
async def test_call_tool_async_import_error(self):
"""测试 mcp 未安装时抛出 ImportError"""
Expand Down
Loading