diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py index 1d34a5546b..31eb58b5b9 100644 --- a/src/mcp/server/auth/middleware/auth_context.py +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -1,5 +1,8 @@ import contextvars +from contextvars import Token +from typing import Any +from starlette.requests import Request from starlette.types import ASGIApp, Receive, Scope, Send from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser @@ -20,6 +23,30 @@ def get_access_token() -> AccessToken | None: return auth_user.access_token if auth_user else None +def push_auth_context_from_request(request: Request | None) -> Token[AuthenticatedUser | None] | None: + """Set auth context for the current task from an incoming request. + + This is primarily used by server transports where request handlers may run + in background tasks that are not part of the original ASGI request task. + """ + if request is None: + return None + # Avoid Request.user, which asserts AuthenticationMiddleware is installed. + user: Any | None = request.scope.get("user") + if user is None: + try: + user = getattr(request, "user", None) + except AssertionError: + user = None + return auth_context_var.set(user if isinstance(user, AuthenticatedUser) else None) + + +def pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None: + if token is None: + return + auth_context_var.reset(token) + + class AuthContextMiddleware: """Middleware that extracts the authenticated user from the request and sets it in a contextvar for easy access throughout the request lifecycle. diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index 4f8d23b8dd..aacb9043d6 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -26,6 +26,7 @@ from pydantic import BaseModel, ValidationError from typing_extensions import TypeVar +from mcp.server.auth.middleware.auth_context import pop_auth_context, push_auth_context_from_request from mcp.server.connection import Connection from mcp.server.context import CallNext, HandlerResult, ServerMiddleware, ServerRequestContext from mcp.server.models import InitializationOptions @@ -259,7 +260,11 @@ async def _inner() -> HandlerResult: return result call = self._compose_server_middleware(ctx, method, params, _inner) - result = _dump_result(await call()) + auth_token = push_auth_context_from_request(ctx.request) + try: + result = _dump_result(await call()) + finally: + pop_auth_context(auth_token) if method == "initialize": # Commit only on chain success, so a middleware veto leaves no state. # Race-free: the read loop is parked until this call returns. diff --git a/tests/server/auth/test_get_access_token_streamable_http.py b/tests/server/auth/test_get_access_token_streamable_http.py new file mode 100644 index 0000000000..9edf068477 --- /dev/null +++ b/tests/server/auth/test_get_access_token_streamable_http.py @@ -0,0 +1,97 @@ +import time + +import httpx +import pytest +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.routing import Mount + +from mcp import Client +from mcp.client.streamable_http import streamable_http_client +from mcp.server import Server, ServerRequestContext +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend +from mcp.server.auth.provider import AccessToken +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.types import ( + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, + Tool, +) + + +class _EchoTokenVerifier: + """Accepts any bearer token and echoes it back as the verified AccessToken.""" + + async def verify_token(self, token: str) -> AccessToken | None: + return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600) + + +async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: + access = get_access_token() + text = access.token if access else "" + return CallToolResult(content=[TextContent(type="text", text=text)]) + + +async def _handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object", "properties": {}})]) + + +class _MutableBearerAuth(httpx.Auth): + def __init__(self, token: str | None) -> None: + self.token = token + + def auth_flow(self, request: httpx.Request): + if self.token is not None: + request.headers["Authorization"] = f"Bearer {self.token}" + yield request + + +async def _call_whoami(asgi_app: Starlette, host: str, token: str | None) -> str: + auth = _MutableBearerAuth(token) + async with ( + httpx.ASGITransport(asgi_app) as transport, + httpx.AsyncClient( + transport=transport, + base_url=f"http://{host}", + auth=auth, + timeout=httpx.Timeout(30, read=30), + follow_redirects=True, + ) as http_client, + ): + transport_ctx = streamable_http_client(f"http://{host}/mcp", http_client=http_client) + async with Client(transport_ctx) as client: # pragma: no branch + result = await client.call_tool("whoami", {}) + assert isinstance(result.content[0], TextContent) + return result.content[0].text + + +@pytest.mark.anyio +async def test_get_access_token_reflects_current_request_in_stateful_session() -> None: + host = "testserver" + + server = Server( + "auth-test-server", + on_call_tool=_handle_whoami, + on_list_tools=_handle_list_tools, + ) + + session_manager = StreamableHTTPSessionManager(app=server, stateless=False) + + asgi_app = Starlette( + routes=[Mount("/mcp", app=session_manager.handle_request)], + middleware=[ + Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())), + Middleware(AuthContextMiddleware), + ], + lifespan=lambda app: session_manager.run(), + ) + + async with asgi_app.router.lifespan_context(asgi_app): + assert await _call_whoami(asgi_app, host, "token-A") == "token-A" + assert await _call_whoami(asgi_app, host, "token-B") == "token-B" + assert await _call_whoami(asgi_app, host, None) == ""