Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion src/acp/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ class StreamEvent:
StreamObserver = Callable[[StreamEvent], Awaitable[None] | None]


class _OversizedLineSkipped(Exception):
"""Raised after an oversized line has been discarded."""


class Connection:
"""Minimal JSON-RPC 2.0 connection over newline-delimited JSON frames."""

Expand Down Expand Up @@ -153,7 +157,15 @@ async def send_notification(self, method: str, params: JsonValue | None = None)
async def _receive_loop(self) -> None:
try:
while True:
line = await asyncio.wait_for(self._reader.readline(), timeout=self._receive_timeout)
try:
line = await self._read_line()
except _OversizedLineSkipped:
logging.warning(
"Skipped oversized JSON-RPC frame that exceeded the StreamReader line limit. "
"The connection will continue with subsequent frames. If large frames are expected, "
"increase the StreamReader limit, for example via stdio_buffer_limit_bytes when using run_agent."
)
continue
if not line:
break
line = line.strip()
Expand All @@ -172,6 +184,36 @@ async def _receive_loop(self) -> None:
raise RequestError.internal_error({"details": "Agent timeout"}) from None
self._disconnect()

async def _read_line(self) -> bytes:
try:
return await self._wait_for_reader(self._reader.readuntil(b"\n"))
except asyncio.IncompleteReadError as exc:
return exc.partial
except asyncio.LimitOverrunError as exc:
await self._discard_oversized_line(exc.consumed)
raise _OversizedLineSkipped from exc

async def _discard_oversized_line(self, consumed: int) -> None:
while True:
if consumed <= 0:
consumed = 1
if consumed > 0:
try:
await self._wait_for_reader(self._reader.readexactly(consumed))
except asyncio.IncompleteReadError:
return
try:
await self._wait_for_reader(self._reader.readuntil(b"\n"))
except asyncio.IncompleteReadError:
return
except asyncio.LimitOverrunError as exc:
consumed = exc.consumed
else:
return

async def _wait_for_reader(self, awaitable: Awaitable[bytes]) -> bytes:
return await asyncio.wait_for(awaitable, timeout=self._receive_timeout)

async def _process_message(self, message: dict[str, Any]) -> None:
method = message.get("method")
has_id = "id" in message
Expand Down
8 changes: 4 additions & 4 deletions tests/real_user/test_stdio_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ async def list_capabilities(self):
[sys.executable, small_agent], input=large_msg, capture_output=True, text=True, timeout=2
)

# Should have errors in stderr about the buffer limit
assert "Error" in result.stderr or result.returncode != 0, (
f"Expected error with small buffer, got: {result.stderr}"
)
assert result.returncode == 0
assert "LimitOverrunError" not in result.stderr
assert "Separator is found, but chunk is longer than limit" not in result.stderr
assert "oversized JSON-RPC frame" in result.stderr

# Test 2: Large buffer (200KB) succeeds with large message (70KB)
large_agent = os.path.join(tmpdir, "large_agent.py")
Expand Down
109 changes: 109 additions & 0 deletions tests/test_connection_recovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from __future__ import annotations

import asyncio
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest

from acp.connection import Connection
from acp.exceptions import RequestError


async def _noop_handler(method: str, params: Any, is_notification: bool) -> Any:
return None


def _make_connection(
*,
limit: int = 128,
receive_timeout: float | None = None,
) -> tuple[Connection, asyncio.StreamReader]:
reader = asyncio.StreamReader(limit=limit)
transport = MagicMock()
transport.is_closing.return_value = False
protocol = AsyncMock()
writer = asyncio.StreamWriter(transport, protocol, reader, asyncio.get_running_loop())
conn = Connection(_noop_handler, writer, reader, listening=False, receive_timeout=receive_timeout)
return conn, reader


@pytest.mark.asyncio
async def test_receive_loop_recovers_from_oversized_frame(caplog: pytest.LogCaptureFixture) -> None:
conn, reader = _make_connection(limit=128)
processed: list[str] = []

async def tracking_process(message: dict[str, Any]) -> None:
processed.append(message["method"])

conn._process_message = tracking_process # type: ignore[method-assign]
oversized = {"jsonrpc": "2.0", "method": "too-large", "params": {"data": "X" * 256}}
survivor = {"jsonrpc": "2.0", "method": "survivor"}
reader.feed_data(json.dumps(oversized).encode() + b"\n" + json.dumps(survivor).encode() + b"\n")
reader.feed_eof()

with caplog.at_level("WARNING"):
await conn._receive_loop()
await conn.close()

assert processed == ["survivor"]
assert any("oversized JSON-RPC frame" in record.message for record in caplog.records)


@pytest.mark.asyncio
async def test_receive_loop_recovers_from_consecutive_oversized_frames() -> None:
conn, reader = _make_connection(limit=128)
processed: list[str] = []

async def tracking_process(message: dict[str, Any]) -> None:
processed.append(message["method"])

conn._process_message = tracking_process # type: ignore[method-assign]
for index in range(2):
oversized = {"jsonrpc": "2.0", "method": f"too-large-{index}", "params": {"data": "Y" * 256}}
reader.feed_data(json.dumps(oversized).encode() + b"\n")
survivor = {"jsonrpc": "2.0", "method": "survivor"}
reader.feed_data(json.dumps(survivor).encode() + b"\n")
reader.feed_eof()

await conn._receive_loop()
await conn.close()

assert processed == ["survivor"]


@pytest.mark.asyncio
async def test_receive_loop_handles_eof_during_oversized_frame() -> None:
conn, reader = _make_connection(limit=64)
reader.feed_data(b"X" * 256)
reader.feed_eof()

await conn._receive_loop()
await conn.close()

assert conn._disconnected is True


@pytest.mark.asyncio
async def test_receive_loop_keeps_timeout_semantics() -> None:
conn, _reader = _make_connection(receive_timeout=0.01)

with pytest.raises(RequestError) as exc_info:
await conn._receive_loop()
await conn.close()

exc = exc_info.value
assert isinstance(exc, RequestError)
assert str(exc) == "Internal error"
assert exc.data == {"details": "Agent timeout"}


@pytest.mark.asyncio
async def test_receive_loop_does_not_swallow_unrelated_reader_error() -> None:
conn, reader = _make_connection()
reader.set_exception(ValueError("reader failed"))

with pytest.raises(ValueError, match="reader failed"):
await conn._receive_loop()
await conn.close()
Loading