|
| 1 | +from contextlib import AsyncExitStack |
1 | 2 | from unittest.mock import AsyncMock, patch |
2 | 3 |
|
3 | 4 | import pytest |
@@ -67,3 +68,111 @@ async def test_manual_connect_disconnect_works( |
67 | 68 |
|
68 | 69 | await server.cleanup() |
69 | 70 | assert server.session is None, "Server should be disconnected" |
| 71 | + |
| 72 | + |
| 73 | +@pytest.mark.asyncio |
| 74 | +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) |
| 75 | +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) |
| 76 | +async def test_cleanup_resets_state_for_reconnection(mock_initialize: AsyncMock, mock_stdio_client): |
| 77 | + """Test that cleanup resets all session state so the same instance can reconnect.""" |
| 78 | + server = MCPServerStdio( |
| 79 | + params={"command": tee}, |
| 80 | + cache_tools_list=True, |
| 81 | + ) |
| 82 | + |
| 83 | + await server.connect() |
| 84 | + first_exit_stack = server.exit_stack |
| 85 | + assert server.session is not None |
| 86 | + assert server.server_initialize_result is not None or mock_initialize.return_value is None |
| 87 | + |
| 88 | + await server.cleanup() |
| 89 | + |
| 90 | + # All session state must be cleared |
| 91 | + assert server.session is None |
| 92 | + assert server.server_initialize_result is None |
| 93 | + assert server._get_session_id is None |
| 94 | + # Exit stack must be a fresh instance so a subsequent connect() works |
| 95 | + assert isinstance(server.exit_stack, AsyncExitStack) |
| 96 | + assert server.exit_stack is not first_exit_stack |
| 97 | + |
| 98 | + |
| 99 | +@pytest.mark.asyncio |
| 100 | +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) |
| 101 | +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) |
| 102 | +@patch("mcp.client.session.ClientSession.list_tools") |
| 103 | +async def test_reconnect_after_cleanup( |
| 104 | + mock_list_tools: AsyncMock, mock_initialize: AsyncMock, mock_stdio_client |
| 105 | +): |
| 106 | + """Test that an MCPServerStdio instance can reconnect after cleanup.""" |
| 107 | + server = MCPServerStdio( |
| 108 | + params={"command": tee}, |
| 109 | + cache_tools_list=True, |
| 110 | + ) |
| 111 | + |
| 112 | + tools = [MCPTool(name="tool1", inputSchema={})] |
| 113 | + mock_list_tools.return_value = ListToolsResult(tools=tools) |
| 114 | + |
| 115 | + # First connection cycle |
| 116 | + await server.connect() |
| 117 | + result = await server.list_tools() |
| 118 | + assert len(result) == 1 |
| 119 | + await server.cleanup() |
| 120 | + assert server.session is None |
| 121 | + |
| 122 | + # Second connection cycle on the same instance |
| 123 | + await server.connect() |
| 124 | + assert server.session is not None |
| 125 | + result = await server.list_tools() |
| 126 | + assert len(result) == 1 |
| 127 | + await server.cleanup() |
| 128 | + assert server.session is None |
| 129 | + |
| 130 | + |
| 131 | +@pytest.mark.asyncio |
| 132 | +@patch("mcp.client.stdio.stdio_client", return_value=DummyStreamsContextManager()) |
| 133 | +@patch("mcp.client.session.ClientSession.initialize", new_callable=AsyncMock, return_value=None) |
| 134 | +async def test_cleanup_closes_write_stream_before_exit_stack( |
| 135 | + mock_initialize: AsyncMock, mock_stdio_client |
| 136 | +): |
| 137 | + """Test that cleanup closes the session write stream before unwinding the exit stack. |
| 138 | +
|
| 139 | + This ordering ensures the subprocess receives EOF on stdin and can shut down |
| 140 | + gracefully (releasing multiprocessing semaphores, etc.) before task-group |
| 141 | + cancellation kills the reader/writer coroutines inside the transport. |
| 142 | + """ |
| 143 | + server = MCPServerStdio( |
| 144 | + params={"command": tee}, |
| 145 | + ) |
| 146 | + |
| 147 | + await server.connect() |
| 148 | + assert server.session is not None |
| 149 | + |
| 150 | + # Track the order of operations during cleanup |
| 151 | + call_order: list[str] = [] |
| 152 | + original_aclose = server.session._write_stream.aclose |
| 153 | + |
| 154 | + async def tracked_write_stream_close(): |
| 155 | + call_order.append("write_stream_closed") |
| 156 | + return await original_aclose() |
| 157 | + |
| 158 | + original_exit_stack_aclose = server.exit_stack.aclose |
| 159 | + |
| 160 | + async def tracked_exit_stack_aclose(): |
| 161 | + call_order.append("exit_stack_closed") |
| 162 | + return await original_exit_stack_aclose() |
| 163 | + |
| 164 | + server.session._write_stream.aclose = tracked_write_stream_close # type: ignore[assignment] |
| 165 | + server.exit_stack.aclose = tracked_exit_stack_aclose # type: ignore[assignment] |
| 166 | + |
| 167 | + await server.cleanup() |
| 168 | + |
| 169 | + # The write stream may be closed multiple times (our explicit close, then again |
| 170 | + # during exit_stack unwind by ClientSession.__aexit__). The critical invariant |
| 171 | + # is that the FIRST close happens before the exit_stack unwind begins. |
| 172 | + assert len(call_order) >= 2, f"Expected at least 2 calls, got: {call_order}" |
| 173 | + assert call_order[0] == "write_stream_closed", ( |
| 174 | + f"Write stream must be closed first, got: {call_order}" |
| 175 | + ) |
| 176 | + assert call_order[1] == "exit_stack_closed", ( |
| 177 | + f"Exit stack must be closed after write stream, got: {call_order}" |
| 178 | + ) |
0 commit comments