Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix StreamableHTTP transport API for backwards compatibility and clea…
…ner header handling

This commit addresses three issues identified in PR review:

1. **Restore RequestContext fields for backwards compatibility**
   - Re-add `headers` and `sse_read_timeout` fields as optional with None defaults
   - Mark them as deprecated in docstring since they're no longer used internally
   - Prevents breaking changes for any code accessing these fields

2. **Add runtime deprecation warnings for StreamableHTTPTransport constructor**
   - Use sentinel value pattern to detect when deprecated parameters are passed
   - Issue DeprecationWarning at runtime when headers, timeout, sse_read_timeout, or auth are provided
   - Complements existing @deprecated decorator for type checkers with actual runtime warnings
   - Improve deprecation message clarity

3. **Simplify header handling by removing redundant client parameter**
   - Remove `client` parameter from `_prepare_headers()` method
   - Stop extracting and re-passing client.headers since httpx automatically merges them
   - Only build MCP-specific headers (Accept, Content-Type, session headers)
   - httpx merges these with client.headers automatically, with our headers taking precedence
   - Reduces code complexity and eliminates unnecessary header extraction

The header handling change leverages httpx's built-in header merging behavior,
similar to how headers were handled before the refactoring but without the
redundant extraction-and-repass pattern.
  • Loading branch information
felixweinberger committed Dec 4, 2025
commit d64d191b6aaf59dbafb5f8d7015a77fb8f8069b5
85 changes: 50 additions & 35 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@
JSON = "application/json"
SSE = "text/event-stream"

# Sentinel value for detecting unset optional parameters
_UNSET = object()


class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""
Expand All @@ -81,7 +84,8 @@ class RequestContext:
session_message: SessionMessage
metadata: ClientMessageMetadata | None
read_stream_writer: StreamWriter
sse_read_timeout: float
headers: dict[str, str] | None = None # Deprecated - no longer used
sse_read_timeout: float | None = None # Deprecated - no longer used
Comment on lines +84 to +85
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could also be breaking if anyone was directly instantiating these classes and relying on the config here.



class StreamableHTTPTransport:
Expand All @@ -90,8 +94,11 @@ class StreamableHTTPTransport:
@overload
def __init__(self, url: str) -> None: ...

@deprecated("Those parameters are deprecated. Use the url parameter instead.")
@overload
@deprecated(
"Parameters headers, timeout, sse_read_timeout, and auth are deprecated. "
"Configure these on the httpx.AsyncClient instead."
)
def __init__(
self,
url: str,
Expand All @@ -104,11 +111,10 @@ def __init__(
def __init__(
self,
url: str,
headers: dict[str, str] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
auth: httpx.Auth | None = None,
**deprecated: dict[str, Any],
headers: Any = _UNSET,
timeout: Any = _UNSET,
sse_read_timeout: Any = _UNSET,
auth: Any = _UNSET,
) -> None:
"""Initialize the StreamableHTTP transport.

Expand All @@ -119,26 +125,40 @@ def __init__(
sse_read_timeout: Timeout for SSE read operations.
auth: Optional HTTPX authentication handler.
"""
if deprecated:
warn(f"Deprecated parameters: {deprecated}", DeprecationWarning)
# Check for deprecated parameters and issue runtime warning
deprecated_params: list[str] = []
if headers is not _UNSET:
deprecated_params.append("headers")
if timeout is not _UNSET:
deprecated_params.append("timeout")
if sse_read_timeout is not _UNSET:
deprecated_params.append("sse_read_timeout")
if auth is not _UNSET:
deprecated_params.append("auth")

if deprecated_params:
warn(
f"Parameters {', '.join(deprecated_params)} are deprecated and will be ignored. "
"Configure these on the httpx.AsyncClient instead.",
DeprecationWarning,
stacklevel=2,
)

self.url = url
self.headers = headers or {}
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
self.sse_read_timeout = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)
self.auth = auth
self.session_id = None
self.protocol_version = None
self.request_headers = {
**self.headers,
ACCEPT: f"{JSON}, {SSE}",
CONTENT_TYPE: JSON,
}

def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID and protocol version if available."""
headers = base_headers.copy()

def _prepare_headers(self) -> dict[str, str]:
"""Build MCP-specific request headers.

These headers will be merged with the httpx.AsyncClient's default headers,
with these MCP-specific headers taking precedence.
"""
headers: dict[str, str] = {}
# Add MCP protocol headers
headers[ACCEPT] = f"{JSON}, {SSE}"
headers[CONTENT_TYPE] = JSON
# Add session headers if available
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
if self.protocol_version:
Expand Down Expand Up @@ -242,7 +262,7 @@ async def handle_get_stream(
if not self.session_id:
return

headers = self._prepare_request_headers(self.request_headers)
headers = self._prepare_headers()
if last_event_id:
headers[LAST_EVENT_ID] = last_event_id # pragma: no cover

Expand All @@ -251,7 +271,6 @@ async def handle_get_stream(
"GET",
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
) as event_source:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")
Expand Down Expand Up @@ -284,7 +303,7 @@ async def handle_get_stream(

async def _handle_resumption_request(self, ctx: RequestContext) -> None:
"""Handle a resumption request using GET with SSE."""
headers = self._prepare_request_headers(ctx.headers)
headers = self._prepare_headers()
if ctx.metadata and ctx.metadata.resumption_token:
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
else:
Expand All @@ -300,7 +319,6 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
"GET",
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
) as event_source:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")
Expand All @@ -318,7 +336,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._prepare_request_headers(ctx.headers)
headers = self._prepare_headers()
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)

Expand Down Expand Up @@ -436,7 +454,7 @@ async def _handle_reconnection(
delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
await anyio.sleep(delay_ms / 1000.0)

headers = self._prepare_request_headers(ctx.headers)
headers = self._prepare_headers()
headers[LAST_EVENT_ID] = last_event_id

# Extract original request ID to map responses
Expand All @@ -450,7 +468,6 @@ async def _handle_reconnection(
"GET",
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
) as event_source:
event_source.response.raise_for_status()
logger.info("Reconnected to SSE stream")
Expand Down Expand Up @@ -538,12 +555,10 @@ async def post_writer(

ctx = RequestContext(
client=client,
headers=self.request_headers,
session_id=self.session_id,
session_message=session_message,
metadata=metadata,
read_stream_writer=read_stream_writer,
sse_read_timeout=self.sse_read_timeout,
)

async def handle_request_async():
Expand All @@ -570,7 +585,7 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: # pragma:
return

try:
headers = self._prepare_request_headers(self.request_headers)
headers = self._prepare_headers()
response = await client.delete(self.url, headers=headers)

if response.status_code == 405:
Expand Down Expand Up @@ -678,8 +693,8 @@ def start_get_stream() -> None:
await write_stream.aclose()


@deprecated("Use `streamable_http_client` instead.")
@asynccontextmanager
@deprecated("Use `streamable_http_client` instead.")
async def streamablehttp_client(
url: str,
headers: dict[str, str] | None = None,
Expand Down
6 changes: 3 additions & 3 deletions tests/client/test_http_unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest

from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamablehttp_client
from mcp.client.streamable_http import streamable_http_client
from tests.test_helpers import wait_for_server

# Test constants with various Unicode characters
Expand Down Expand Up @@ -178,7 +178,7 @@ async def test_streamable_http_client_unicode_tool_call(running_unicode_server:
base_url = running_unicode_server
endpoint_url = f"{base_url}/mcp"

async with streamablehttp_client(endpoint_url) as (read_stream, write_stream, _get_session_id):
async with streamable_http_client(endpoint_url) as (read_stream, write_stream, _get_session_id):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()

Expand Down Expand Up @@ -210,7 +210,7 @@ async def test_streamable_http_client_unicode_prompts(running_unicode_server: st
base_url = running_unicode_server
endpoint_url = f"{base_url}/mcp"

async with streamablehttp_client(endpoint_url) as (read_stream, write_stream, _get_session_id):
async with streamable_http_client(endpoint_url) as (read_stream, write_stream, _get_session_id):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()

Expand Down
40 changes: 39 additions & 1 deletion tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import socket
import time
from collections.abc import Generator
from datetime import timedelta
from typing import Any
from unittest.mock import MagicMock

Expand All @@ -25,7 +26,11 @@

import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client
from mcp.client.streamable_http import (
StreamableHTTPTransport,
streamable_http_client,
streamablehttp_client, # pyright: ignore[reportDeprecated]
)
from mcp.server import Server
from mcp.server.streamable_http import (
MCP_PROTOCOL_VERSION_HEADER,
Expand Down Expand Up @@ -2356,3 +2361,36 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(

assert "content-type" in headers_data
assert headers_data["content-type"] == "application/json"


@pytest.mark.anyio
async def test_streamable_http_transport_deprecated_params_ignored(basic_server: None, basic_server_url: str) -> None:
"""Test that deprecated parameters passed to StreamableHTTPTransport are properly ignored."""
with pytest.warns(DeprecationWarning):
transport = StreamableHTTPTransport( # pyright: ignore[reportDeprecated]
url=f"{basic_server_url}/mcp",
headers={"X-Should-Be-Ignored": "ignored"},
timeout=999,
sse_read_timeout=timedelta(seconds=999),
auth=None,
)

headers = transport._prepare_headers()
assert "X-Should-Be-Ignored" not in headers
assert headers["accept"] == "application/json, text/event-stream"
assert headers["content-type"] == "application/json"


@pytest.mark.anyio
async def test_streamablehttp_client_deprecation_warning(basic_server: None, basic_server_url: str) -> None:
"""Test that the old streamablehttp_client() function issues a deprecation warning."""
with pytest.warns(DeprecationWarning, match="Use `streamable_http_client` instead"):
async with streamablehttp_client(f"{basic_server_url}/mcp") as ( # pyright: ignore[reportDeprecated]
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
tools = await session.list_tools()
assert len(tools.tools) > 0