diff --git a/src/acp/connection.py b/src/acp/connection.py index ff1cb19..1ca7763 100644 --- a/src/acp/connection.py +++ b/src/acp/connection.py @@ -58,6 +58,10 @@ class StreamEvent: StreamObserver = Callable[[StreamEvent], Awaitable[None] | None] +class _OversizedLineSkipped(Exception): + """Raised after an oversized line has been discarded.""" + + class Connection: """Minimal JSON-RPC 2.0 connection over newline-delimited JSON frames.""" @@ -153,7 +157,15 @@ async def send_notification(self, method: str, params: JsonValue | None = None) async def _receive_loop(self) -> None: try: while True: - line = await asyncio.wait_for(self._reader.readline(), timeout=self._receive_timeout) + try: + line = await self._read_line() + except _OversizedLineSkipped: + logging.warning( + "Skipped oversized JSON-RPC frame that exceeded the StreamReader line limit. " + "The connection will continue with subsequent frames. If large frames are expected, " + "increase the StreamReader limit, for example via stdio_buffer_limit_bytes when using run_agent." + ) + continue if not line: break line = line.strip() @@ -172,6 +184,36 @@ async def _receive_loop(self) -> None: raise RequestError.internal_error({"details": "Agent timeout"}) from None self._disconnect() + async def _read_line(self) -> bytes: + try: + return await self._wait_for_reader(self._reader.readuntil(b"\n")) + except asyncio.IncompleteReadError as exc: + return exc.partial + except asyncio.LimitOverrunError as exc: + await self._discard_oversized_line(exc.consumed) + raise _OversizedLineSkipped from exc + + async def _discard_oversized_line(self, consumed: int) -> None: + while True: + if consumed <= 0: + consumed = 1 + if consumed > 0: + try: + await self._wait_for_reader(self._reader.readexactly(consumed)) + except asyncio.IncompleteReadError: + return + try: + await self._wait_for_reader(self._reader.readuntil(b"\n")) + except asyncio.IncompleteReadError: + return + except asyncio.LimitOverrunError as exc: + consumed = exc.consumed + else: + return + + async def _wait_for_reader(self, awaitable: Awaitable[bytes]) -> bytes: + return await asyncio.wait_for(awaitable, timeout=self._receive_timeout) + async def _process_message(self, message: dict[str, Any]) -> None: method = message.get("method") has_id = "id" in message diff --git a/tests/real_user/test_stdio_limits.py b/tests/real_user/test_stdio_limits.py index f972a8f..050aaa7 100644 --- a/tests/real_user/test_stdio_limits.py +++ b/tests/real_user/test_stdio_limits.py @@ -69,10 +69,10 @@ async def list_capabilities(self): [sys.executable, small_agent], input=large_msg, capture_output=True, text=True, timeout=2 ) - # Should have errors in stderr about the buffer limit - assert "Error" in result.stderr or result.returncode != 0, ( - f"Expected error with small buffer, got: {result.stderr}" - ) + assert result.returncode == 0 + assert "LimitOverrunError" not in result.stderr + assert "Separator is found, but chunk is longer than limit" not in result.stderr + assert "oversized JSON-RPC frame" in result.stderr # Test 2: Large buffer (200KB) succeeds with large message (70KB) large_agent = os.path.join(tmpdir, "large_agent.py") diff --git a/tests/test_connection_recovery.py b/tests/test_connection_recovery.py new file mode 100644 index 0000000..dae5a32 --- /dev/null +++ b/tests/test_connection_recovery.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import asyncio +import json +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from acp.connection import Connection +from acp.exceptions import RequestError + + +async def _noop_handler(method: str, params: Any, is_notification: bool) -> Any: + return None + + +def _make_connection( + *, + limit: int = 128, + receive_timeout: float | None = None, +) -> tuple[Connection, asyncio.StreamReader]: + reader = asyncio.StreamReader(limit=limit) + transport = MagicMock() + transport.is_closing.return_value = False + protocol = AsyncMock() + writer = asyncio.StreamWriter(transport, protocol, reader, asyncio.get_running_loop()) + conn = Connection(_noop_handler, writer, reader, listening=False, receive_timeout=receive_timeout) + return conn, reader + + +@pytest.mark.asyncio +async def test_receive_loop_recovers_from_oversized_frame(caplog: pytest.LogCaptureFixture) -> None: + conn, reader = _make_connection(limit=128) + processed: list[str] = [] + + async def tracking_process(message: dict[str, Any]) -> None: + processed.append(message["method"]) + + conn._process_message = tracking_process # type: ignore[method-assign] + oversized = {"jsonrpc": "2.0", "method": "too-large", "params": {"data": "X" * 256}} + survivor = {"jsonrpc": "2.0", "method": "survivor"} + reader.feed_data(json.dumps(oversized).encode() + b"\n" + json.dumps(survivor).encode() + b"\n") + reader.feed_eof() + + with caplog.at_level("WARNING"): + await conn._receive_loop() + await conn.close() + + assert processed == ["survivor"] + assert any("oversized JSON-RPC frame" in record.message for record in caplog.records) + + +@pytest.mark.asyncio +async def test_receive_loop_recovers_from_consecutive_oversized_frames() -> None: + conn, reader = _make_connection(limit=128) + processed: list[str] = [] + + async def tracking_process(message: dict[str, Any]) -> None: + processed.append(message["method"]) + + conn._process_message = tracking_process # type: ignore[method-assign] + for index in range(2): + oversized = {"jsonrpc": "2.0", "method": f"too-large-{index}", "params": {"data": "Y" * 256}} + reader.feed_data(json.dumps(oversized).encode() + b"\n") + survivor = {"jsonrpc": "2.0", "method": "survivor"} + reader.feed_data(json.dumps(survivor).encode() + b"\n") + reader.feed_eof() + + await conn._receive_loop() + await conn.close() + + assert processed == ["survivor"] + + +@pytest.mark.asyncio +async def test_receive_loop_handles_eof_during_oversized_frame() -> None: + conn, reader = _make_connection(limit=64) + reader.feed_data(b"X" * 256) + reader.feed_eof() + + await conn._receive_loop() + await conn.close() + + assert conn._disconnected is True + + +@pytest.mark.asyncio +async def test_receive_loop_keeps_timeout_semantics() -> None: + conn, _reader = _make_connection(receive_timeout=0.01) + + with pytest.raises(RequestError) as exc_info: + await conn._receive_loop() + await conn.close() + + exc = exc_info.value + assert isinstance(exc, RequestError) + assert str(exc) == "Internal error" + assert exc.data == {"details": "Agent timeout"} + + +@pytest.mark.asyncio +async def test_receive_loop_does_not_swallow_unrelated_reader_error() -> None: + conn, reader = _make_connection() + reader.set_exception(ValueError("reader failed")) + + with pytest.raises(ValueError, match="reader failed"): + await conn._receive_loop() + await conn.close()