Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/mcp/server/auth/middleware/auth_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import contextvars
from contextvars import Token
from typing import Any

from starlette.requests import Request
from starlette.types import ASGIApp, Receive, Scope, Send

from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
Expand All @@ -20,6 +23,30 @@ def get_access_token() -> AccessToken | None:
return auth_user.access_token if auth_user else None


def push_auth_context_from_request(request: Request | None) -> Token[AuthenticatedUser | None] | None:
"""Set auth context for the current task from an incoming request.

This is primarily used by server transports where request handlers may run
in background tasks that are not part of the original ASGI request task.
"""
if request is None:
return None
# Avoid Request.user, which asserts AuthenticationMiddleware is installed.
user: Any | None = request.scope.get("user")
if user is None:
try:
user = getattr(request, "user", None)
except AssertionError:
user = None
return auth_context_var.set(user if isinstance(user, AuthenticatedUser) else None)


def pop_auth_context(token: Token[AuthenticatedUser | None] | None) -> None:
if token is None:
return
auth_context_var.reset(token)


class AuthContextMiddleware:
"""Middleware that extracts the authenticated user from the request
and sets it in a contextvar for easy access throughout the request lifecycle.
Expand Down
7 changes: 6 additions & 1 deletion src/mcp/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pydantic import BaseModel, ValidationError
from typing_extensions import TypeVar

from mcp.server.auth.middleware.auth_context import pop_auth_context, push_auth_context_from_request
from mcp.server.connection import Connection
from mcp.server.context import CallNext, HandlerResult, ServerMiddleware, ServerRequestContext
from mcp.server.models import InitializationOptions
Expand Down Expand Up @@ -259,7 +260,11 @@ async def _inner() -> HandlerResult:
return result

call = self._compose_server_middleware(ctx, method, params, _inner)
result = _dump_result(await call())
auth_token = push_auth_context_from_request(ctx.request)
try:
result = _dump_result(await call())
finally:
pop_auth_context(auth_token)
if method == "initialize":
# Commit only on chain success, so a middleware veto leaves no state.
# Race-free: the read loop is parked until this call returns.
Expand Down
97 changes: 97 additions & 0 deletions tests/server/auth/test_get_access_token_streamable_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import time

import httpx
import pytest
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.routing import Mount

from mcp import Client
from mcp.client.streamable_http import streamable_http_client
from mcp.server import Server, ServerRequestContext
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware, get_access_token
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend
from mcp.server.auth.provider import AccessToken
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.types import (
CallToolRequestParams,
CallToolResult,
ListToolsResult,
PaginatedRequestParams,
TextContent,
Tool,
)


class _EchoTokenVerifier:
"""Accepts any bearer token and echoes it back as the verified AccessToken."""

async def verify_token(self, token: str) -> AccessToken | None:
return AccessToken(token=token, client_id=token, scopes=[], expires_at=int(time.time()) + 3600)


async def _handle_whoami(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
access = get_access_token()
text = access.token if access else "<none>"
return CallToolResult(content=[TextContent(type="text", text=text)])


async def _handle_list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
return ListToolsResult(tools=[Tool(name="whoami", input_schema={"type": "object", "properties": {}})])


class _MutableBearerAuth(httpx.Auth):
def __init__(self, token: str | None) -> None:
self.token = token

def auth_flow(self, request: httpx.Request):
if self.token is not None:
request.headers["Authorization"] = f"Bearer {self.token}"
yield request


async def _call_whoami(asgi_app: Starlette, host: str, token: str | None) -> str:
auth = _MutableBearerAuth(token)
async with (
httpx.ASGITransport(asgi_app) as transport,
httpx.AsyncClient(
transport=transport,
base_url=f"http://{host}",
auth=auth,
timeout=httpx.Timeout(30, read=30),
follow_redirects=True,
) as http_client,
):
transport_ctx = streamable_http_client(f"http://{host}/mcp", http_client=http_client)
async with Client(transport_ctx) as client: # pragma: no branch
result = await client.call_tool("whoami", {})
assert isinstance(result.content[0], TextContent)
return result.content[0].text


@pytest.mark.anyio
async def test_get_access_token_reflects_current_request_in_stateful_session() -> None:
host = "testserver"

server = Server(
"auth-test-server",
on_call_tool=_handle_whoami,
on_list_tools=_handle_list_tools,
)

session_manager = StreamableHTTPSessionManager(app=server, stateless=False)

asgi_app = Starlette(
routes=[Mount("/mcp", app=session_manager.handle_request)],
middleware=[
Middleware(AuthenticationMiddleware, backend=BearerAuthBackend(_EchoTokenVerifier())),
Middleware(AuthContextMiddleware),
],
lifespan=lambda app: session_manager.run(),
)

async with asgi_app.router.lifespan_context(asgi_app):
assert await _call_whoami(asgi_app, host, "token-A") == "token-A"
assert await _call_whoami(asgi_app, host, "token-B") == "token-B"
assert await _call_whoami(asgi_app, host, None) == "<none>"
Loading