Fix project isolation: Make loadChatHistory respect active project sessions

- Modified loadChatHistory() to check for active project before fetching all sessions
- When active project exists, use project.sessions instead of fetching from API
- Added detailed console logging to debug session filtering
- This prevents ALL sessions from appearing in every project's sidebar

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
uroma
2026-01-22 14:43:05 +00:00
Unverified
parent b82837aa5f
commit 55aafbae9a
6463 changed files with 1115462 additions and 4486 deletions

View File

@@ -0,0 +1,85 @@
import argparse
import logging
import sys
from functools import partial
from urllib.parse import urlparse
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
if not sys.warnoptions:
import warnings
warnings.simplefilter("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("client")
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
logger.error("Error: %s", message)
return
logger.info("Received message from server: %s", message)
async def run_session(
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
client_info: types.Implementation | None = None,
):
async with ClientSession(
read_stream,
write_stream,
message_handler=message_handler,
client_info=client_info,
) as session:
logger.info("Initializing session")
await session.initialize()
logger.info("Initialized")
async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]):
env_dict = dict(env)
if urlparse(command_or_url).scheme in ("http", "https"):
# Use SSE client for HTTP(S) URLs
async with sse_client(command_or_url) as streams:
await run_session(*streams)
else:
# Use stdio client for commands
server_parameters = StdioServerParameters(command=command_or_url, args=args, env=env_dict)
async with stdio_client(server_parameters) as streams:
await run_session(*streams)
def cli():
parser = argparse.ArgumentParser()
parser.add_argument("command_or_url", help="Command or URL to connect to")
parser.add_argument("args", nargs="*", help="Additional arguments")
parser.add_argument(
"-e",
"--env",
nargs=2,
action="append",
metavar=("KEY", "VALUE"),
help="Environment variables to set. Can be used multiple times.",
default=[],
)
args = parser.parse_args()
anyio.run(partial(main, args.command_or_url, args.args, args.env), backend="trio")
if __name__ == "__main__":
cli()

View File

@@ -0,0 +1,21 @@
"""
OAuth2 Authentication implementation for HTTPX.
Implements authorization code flow with PKCE and automatic token refresh.
"""
from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
from mcp.client.auth.oauth2 import (
OAuthClientProvider,
PKCEParameters,
TokenStorage,
)
__all__ = [
"OAuthClientProvider",
"OAuthFlowError",
"OAuthRegistrationError",
"OAuthTokenError",
"PKCEParameters",
"TokenStorage",
]

View File

@@ -0,0 +1,10 @@
class OAuthFlowError(Exception):
"""Base exception for OAuth flow errors."""
class OAuthTokenError(OAuthFlowError):
"""Raised when token operations fail."""
class OAuthRegistrationError(OAuthFlowError):
"""Raised when client registration fails."""

View File

@@ -0,0 +1,487 @@
"""
OAuth client credential extensions for MCP.
Provides OAuth providers for machine-to-machine authentication flows:
- ClientCredentialsOAuthProvider: For client_credentials with client_id + client_secret
- PrivateKeyJWTOAuthProvider: For client_credentials with private_key_jwt authentication
(typically using a pre-built JWT from workload identity federation)
- RFC7523OAuthClientProvider: For jwt-bearer grant (RFC 7523 Section 2.1)
"""
import time
from collections.abc import Awaitable, Callable
from typing import Any, Literal
from uuid import uuid4
import httpx
import jwt
from pydantic import BaseModel, Field
from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
class ClientCredentialsOAuthProvider(OAuthClientProvider):
"""OAuth provider for client_credentials grant with client_id + client_secret.
This provider sets client_info directly, bypassing dynamic client registration.
Use this when you already have client credentials (client_id and client_secret).
Example:
```python
provider = ClientCredentialsOAuthProvider(
server_url="https://api.example.com",
storage=my_token_storage,
client_id="my-client-id",
client_secret="my-client-secret",
)
```
"""
def __init__(
self,
server_url: str,
storage: TokenStorage,
client_id: str,
client_secret: str,
token_endpoint_auth_method: Literal["client_secret_basic", "client_secret_post"] = "client_secret_basic",
scopes: str | None = None,
) -> None:
"""Initialize client_credentials OAuth provider.
Args:
server_url: The MCP server URL.
storage: Token storage implementation.
client_id: The OAuth client ID.
client_secret: The OAuth client secret.
token_endpoint_auth_method: Authentication method for token endpoint.
Either "client_secret_basic" (default) or "client_secret_post".
scopes: Optional space-separated list of scopes to request.
"""
# Build minimal client_metadata for the base class
client_metadata = OAuthClientMetadata(
redirect_uris=None,
grant_types=["client_credentials"],
token_endpoint_auth_method=token_endpoint_auth_method,
scope=scopes,
)
super().__init__(server_url, client_metadata, storage, None, None, 300.0)
# Store client_info to be set during _initialize - no dynamic registration needed
self._fixed_client_info = OAuthClientInformationFull(
redirect_uris=None,
client_id=client_id,
client_secret=client_secret,
grant_types=["client_credentials"],
token_endpoint_auth_method=token_endpoint_auth_method,
scope=scopes,
)
async def _initialize(self) -> None:
"""Load stored tokens and set pre-configured client_info."""
self.context.current_tokens = await self.context.storage.get_tokens()
self.context.client_info = self._fixed_client_info
self._initialized = True
async def _perform_authorization(self) -> httpx.Request:
"""Perform client_credentials authorization."""
return await self._exchange_token_client_credentials()
async def _exchange_token_client_credentials(self) -> httpx.Request:
"""Build token exchange request for client_credentials grant."""
token_data: dict[str, Any] = {
"grant_type": "client_credentials",
}
headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
# Use standard auth methods (client_secret_basic, client_secret_post, none)
token_data, headers = self.context.prepare_token_auth(token_data, headers)
if self.context.should_include_resource_param(self.context.protocol_version):
token_data["resource"] = self.context.get_resource_url()
if self.context.client_metadata.scope:
token_data["scope"] = self.context.client_metadata.scope
token_url = self._get_token_endpoint()
return httpx.Request("POST", token_url, data=token_data, headers=headers)
def static_assertion_provider(token: str) -> Callable[[str], Awaitable[str]]:
"""Create an assertion provider that returns a static JWT token.
Use this when you have a pre-built JWT (e.g., from workload identity federation)
that doesn't need the audience parameter.
Example:
```python
provider = PrivateKeyJWTOAuthProvider(
server_url="https://api.example.com",
storage=my_token_storage,
client_id="my-client-id",
assertion_provider=static_assertion_provider(my_prebuilt_jwt),
)
```
Args:
token: The pre-built JWT assertion string.
Returns:
An async callback suitable for use as an assertion_provider.
"""
async def provider(audience: str) -> str:
return token
return provider
class SignedJWTParameters(BaseModel):
"""Parameters for creating SDK-signed JWT assertions.
Use `create_assertion_provider()` to create an assertion provider callback
for use with `PrivateKeyJWTOAuthProvider`.
Example:
```python
jwt_params = SignedJWTParameters(
issuer="my-client-id",
subject="my-client-id",
signing_key=private_key_pem,
)
provider = PrivateKeyJWTOAuthProvider(
server_url="https://api.example.com",
storage=my_token_storage,
client_id="my-client-id",
assertion_provider=jwt_params.create_assertion_provider(),
)
```
"""
issuer: str = Field(description="Issuer for JWT assertions (typically client_id).")
subject: str = Field(description="Subject identifier for JWT assertions (typically client_id).")
signing_key: str = Field(description="Private key for JWT signing (PEM format).")
signing_algorithm: str = Field(default="RS256", description="Algorithm for signing JWT assertions.")
lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
additional_claims: dict[str, Any] | None = Field(default=None, description="Additional claims.")
def create_assertion_provider(self) -> Callable[[str], Awaitable[str]]:
"""Create an assertion provider callback for use with PrivateKeyJWTOAuthProvider.
Returns:
An async callback that takes the audience (authorization server issuer URL)
and returns a signed JWT assertion.
"""
async def provider(audience: str) -> str:
now = int(time.time())
claims: dict[str, Any] = {
"iss": self.issuer,
"sub": self.subject,
"aud": audience,
"exp": now + self.lifetime_seconds,
"iat": now,
"jti": str(uuid4()),
}
if self.additional_claims:
claims.update(self.additional_claims)
return jwt.encode(claims, self.signing_key, algorithm=self.signing_algorithm)
return provider
class PrivateKeyJWTOAuthProvider(OAuthClientProvider):
"""OAuth provider for client_credentials grant with private_key_jwt authentication.
Uses RFC 7523 Section 2.2 for client authentication via JWT assertion.
The JWT assertion's audience MUST be the authorization server's issuer identifier
(per RFC 7523bis security updates). The `assertion_provider` callback receives
this audience value and must return a JWT with that audience.
**Option 1: Pre-built JWT via Workload Identity Federation**
In production scenarios, the JWT assertion is typically obtained from a workload
identity provider (e.g., GCP, AWS IAM, Azure AD):
```python
async def get_workload_identity_token(audience: str) -> str:
# Fetch JWT from your identity provider
# The JWT's audience must match the provided audience parameter
return await fetch_token_from_identity_provider(audience=audience)
provider = PrivateKeyJWTOAuthProvider(
server_url="https://api.example.com",
storage=my_token_storage,
client_id="my-client-id",
assertion_provider=get_workload_identity_token,
)
```
**Option 2: Static pre-built JWT**
If you have a static JWT that doesn't need the audience parameter:
```python
provider = PrivateKeyJWTOAuthProvider(
server_url="https://api.example.com",
storage=my_token_storage,
client_id="my-client-id",
assertion_provider=static_assertion_provider(my_prebuilt_jwt),
)
```
**Option 3: SDK-signed JWT (for testing/simple setups)**
For testing or simple deployments, use `SignedJWTParameters.create_assertion_provider()`:
```python
jwt_params = SignedJWTParameters(
issuer="my-client-id",
subject="my-client-id",
signing_key=private_key_pem,
)
provider = PrivateKeyJWTOAuthProvider(
server_url="https://api.example.com",
storage=my_token_storage,
client_id="my-client-id",
assertion_provider=jwt_params.create_assertion_provider(),
)
```
"""
def __init__(
self,
server_url: str,
storage: TokenStorage,
client_id: str,
assertion_provider: Callable[[str], Awaitable[str]],
scopes: str | None = None,
) -> None:
"""Initialize private_key_jwt OAuth provider.
Args:
server_url: The MCP server URL.
storage: Token storage implementation.
client_id: The OAuth client ID.
assertion_provider: Async callback that takes the audience (authorization
server's issuer identifier) and returns a JWT assertion. Use
`SignedJWTParameters.create_assertion_provider()` for SDK-signed JWTs,
`static_assertion_provider()` for pre-built JWTs, or provide your own
callback for workload identity federation.
scopes: Optional space-separated list of scopes to request.
"""
# Build minimal client_metadata for the base class
client_metadata = OAuthClientMetadata(
redirect_uris=None,
grant_types=["client_credentials"],
token_endpoint_auth_method="private_key_jwt",
scope=scopes,
)
super().__init__(server_url, client_metadata, storage, None, None, 300.0)
self._assertion_provider = assertion_provider
# Store client_info to be set during _initialize - no dynamic registration needed
self._fixed_client_info = OAuthClientInformationFull(
redirect_uris=None,
client_id=client_id,
grant_types=["client_credentials"],
token_endpoint_auth_method="private_key_jwt",
scope=scopes,
)
async def _initialize(self) -> None:
"""Load stored tokens and set pre-configured client_info."""
self.context.current_tokens = await self.context.storage.get_tokens()
self.context.client_info = self._fixed_client_info
self._initialized = True
async def _perform_authorization(self) -> httpx.Request:
"""Perform client_credentials authorization with private_key_jwt."""
return await self._exchange_token_client_credentials()
async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> None:
"""Add JWT assertion for client authentication to token endpoint parameters."""
if not self.context.oauth_metadata:
raise OAuthFlowError("Missing OAuth metadata for private_key_jwt flow") # pragma: no cover
# Audience MUST be the issuer identifier of the authorization server
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01
audience = str(self.context.oauth_metadata.issuer)
assertion = await self._assertion_provider(audience)
# RFC 7523 Section 2.2: client authentication via JWT
token_data["client_assertion"] = assertion
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
async def _exchange_token_client_credentials(self) -> httpx.Request:
"""Build token exchange request for client_credentials grant with private_key_jwt."""
token_data: dict[str, Any] = {
"grant_type": "client_credentials",
}
headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
# Add JWT client authentication (RFC 7523 Section 2.2)
await self._add_client_authentication_jwt(token_data=token_data)
if self.context.should_include_resource_param(self.context.protocol_version):
token_data["resource"] = self.context.get_resource_url()
if self.context.client_metadata.scope:
token_data["scope"] = self.context.client_metadata.scope
token_url = self._get_token_endpoint()
return httpx.Request("POST", token_url, data=token_data, headers=headers)
class JWTParameters(BaseModel):
"""JWT parameters."""
assertion: str | None = Field(
default=None,
description="JWT assertion for JWT authentication. "
"Will be used instead of generating a new assertion if provided.",
)
issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
audience: str | None = Field(default=None, description="Audience for JWT assertions.")
claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
def to_assertion(self, with_audience_fallback: str | None = None) -> str:
if self.assertion is not None:
# Prebuilt JWT (e.g. acquired out-of-band)
assertion = self.assertion
else:
if not self.jwt_signing_key:
raise OAuthFlowError("Missing signing key for JWT bearer grant") # pragma: no cover
if not self.issuer:
raise OAuthFlowError("Missing issuer for JWT bearer grant") # pragma: no cover
if not self.subject:
raise OAuthFlowError("Missing subject for JWT bearer grant") # pragma: no cover
audience = self.audience if self.audience else with_audience_fallback
if not audience:
raise OAuthFlowError("Missing audience for JWT bearer grant") # pragma: no cover
now = int(time.time())
claims: dict[str, Any] = {
"iss": self.issuer,
"sub": self.subject,
"aud": audience,
"exp": now + self.jwt_lifetime_seconds,
"iat": now,
"jti": str(uuid4()),
}
claims.update(self.claims or {})
assertion = jwt.encode(
claims,
self.jwt_signing_key,
algorithm=self.jwt_signing_algorithm or "RS256",
)
return assertion
class RFC7523OAuthClientProvider(OAuthClientProvider):
"""OAuth client provider for RFC 7523 jwt-bearer grant.
.. deprecated::
Use :class:`ClientCredentialsOAuthProvider` for client_credentials with
client_id + client_secret, or :class:`PrivateKeyJWTOAuthProvider` for
client_credentials with private_key_jwt authentication instead.
This provider supports the jwt-bearer authorization grant (RFC 7523 Section 2.1)
where the JWT itself is the authorization grant.
"""
def __init__(
self,
server_url: str,
client_metadata: OAuthClientMetadata,
storage: TokenStorage,
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
timeout: float = 300.0,
jwt_parameters: JWTParameters | None = None,
) -> None:
import warnings
warnings.warn(
"RFC7523OAuthClientProvider is deprecated. Use ClientCredentialsOAuthProvider "
"or PrivateKeyJWTOAuthProvider instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
self.jwt_parameters = jwt_parameters
async def _exchange_token_authorization_code(
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
) -> httpx.Request: # pragma: no cover
"""Build token exchange request for authorization_code flow."""
token_data = token_data or {}
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
self._add_client_authentication_jwt(token_data=token_data)
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)
async def _perform_authorization(self) -> httpx.Request: # pragma: no cover
"""Perform the authorization flow."""
if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
token_request = await self._exchange_token_jwt_bearer()
return token_request
else:
return await super()._perform_authorization()
def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # pragma: no cover
"""Add JWT assertion for client authentication to token endpoint parameters."""
if not self.jwt_parameters:
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
if not self.context.oauth_metadata:
raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow")
# We need to set the audience to the issuer identifier of the authorization server
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
issuer = str(self.context.oauth_metadata.issuer)
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
# When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
token_data["client_assertion"] = assertion
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
# We need to set the audience to the resource server, the audience is difference from the one in claims
# it represents the resource server that will validate the token
token_data["audience"] = self.context.get_resource_url()
async def _exchange_token_jwt_bearer(self) -> httpx.Request:
"""Build token exchange request for JWT bearer grant."""
if not self.context.client_info:
raise OAuthFlowError("Missing client info") # pragma: no cover
if not self.jwt_parameters:
raise OAuthFlowError("Missing JWT parameters") # pragma: no cover
if not self.context.oauth_metadata:
raise OAuthTokenError("Missing OAuth metadata") # pragma: no cover
# We need to set the audience to the issuer identifier of the authorization server
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
issuer = str(self.context.oauth_metadata.issuer)
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
token_data = {
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"assertion": assertion,
}
if self.context.should_include_resource_param(self.context.protocol_version): # pragma: no branch
token_data["resource"] = self.context.get_resource_url()
if self.context.client_metadata.scope: # pragma: no branch
token_data["scope"] = self.context.client_metadata.scope
token_url = self._get_token_endpoint()
return httpx.Request(
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
)

View File

@@ -0,0 +1,616 @@
"""
OAuth2 Authentication implementation for HTTPX.
Implements authorization code flow with PKCE and automatic token refresh.
"""
import base64
import hashlib
import logging
import secrets
import string
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Protocol
from urllib.parse import quote, urlencode, urljoin, urlparse
import anyio
import httpx
from pydantic import BaseModel, Field, ValidationError
from mcp.client.auth.exceptions import OAuthFlowError, OAuthTokenError
from mcp.client.auth.utils import (
build_oauth_authorization_server_metadata_discovery_urls,
build_protected_resource_metadata_discovery_urls,
create_client_info_from_metadata_url,
create_client_registration_request,
create_oauth_metadata_request,
extract_field_from_www_auth,
extract_resource_metadata_from_www_auth,
extract_scope_from_www_auth,
get_client_metadata_scopes,
handle_auth_metadata_response,
handle_protected_resource_response,
handle_registration_response,
handle_token_response_scopes,
is_valid_client_metadata_url,
should_use_client_metadata_url,
)
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthToken,
ProtectedResourceMetadata,
)
from mcp.shared.auth_utils import (
calculate_token_expiry,
check_resource_allowed,
resource_url_from_server_url,
)
logger = logging.getLogger(__name__)
class PKCEParameters(BaseModel):
"""PKCE (Proof Key for Code Exchange) parameters."""
code_verifier: str = Field(..., min_length=43, max_length=128)
code_challenge: str = Field(..., min_length=43, max_length=128)
@classmethod
def generate(cls) -> "PKCEParameters":
"""Generate new PKCE parameters."""
code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128))
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=")
return cls(code_verifier=code_verifier, code_challenge=code_challenge)
class TokenStorage(Protocol):
"""Protocol for token storage implementations."""
async def get_tokens(self) -> OAuthToken | None:
"""Get stored tokens."""
...
async def set_tokens(self, tokens: OAuthToken) -> None:
"""Store tokens."""
...
async def get_client_info(self) -> OAuthClientInformationFull | None:
"""Get stored client information."""
...
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
"""Store client information."""
...
@dataclass
class OAuthContext:
"""OAuth flow context."""
server_url: str
client_metadata: OAuthClientMetadata
storage: TokenStorage
redirect_handler: Callable[[str], Awaitable[None]] | None
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None
timeout: float = 300.0
client_metadata_url: str | None = None
# Discovered metadata
protected_resource_metadata: ProtectedResourceMetadata | None = None
oauth_metadata: OAuthMetadata | None = None
auth_server_url: str | None = None
protocol_version: str | None = None
# Client registration
client_info: OAuthClientInformationFull | None = None
# Token management
current_tokens: OAuthToken | None = None
token_expiry_time: float | None = None
# State
lock: anyio.Lock = field(default_factory=anyio.Lock)
def get_authorization_base_url(self, server_url: str) -> str:
"""Extract base URL by removing path component."""
parsed = urlparse(server_url)
return f"{parsed.scheme}://{parsed.netloc}"
def update_token_expiry(self, token: OAuthToken) -> None:
"""Update token expiry time using shared util function."""
self.token_expiry_time = calculate_token_expiry(token.expires_in)
def is_token_valid(self) -> bool:
"""Check if current token is valid."""
return bool(
self.current_tokens
and self.current_tokens.access_token
and (not self.token_expiry_time or time.time() <= self.token_expiry_time)
)
def can_refresh_token(self) -> bool:
"""Check if token can be refreshed."""
return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info)
def clear_tokens(self) -> None:
"""Clear current tokens."""
self.current_tokens = None
self.token_expiry_time = None
def get_resource_url(self) -> str:
"""Get resource URL for RFC 8707.
Uses PRM resource if it's a valid parent, otherwise uses canonical server URL.
"""
resource = resource_url_from_server_url(self.server_url)
# If PRM provides a resource that's a valid parent, use it
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
prm_resource = str(self.protected_resource_metadata.resource)
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
resource = prm_resource
return resource
def should_include_resource_param(self, protocol_version: str | None = None) -> bool:
"""Determine if the resource parameter should be included in OAuth requests.
Returns True if:
- Protected resource metadata is available, OR
- MCP-Protocol-Version header is 2025-06-18 or later
"""
# If we have protected resource metadata, include the resource param
if self.protected_resource_metadata is not None:
return True
# If no protocol version provided, don't include resource param
if not protocol_version:
return False
# Check if protocol version is 2025-06-18 or later
# Version format is YYYY-MM-DD, so string comparison works
return protocol_version >= "2025-06-18"
def prepare_token_auth(
self, data: dict[str, str], headers: dict[str, str] | None = None
) -> tuple[dict[str, str], dict[str, str]]:
"""Prepare authentication for token requests.
Args:
data: The form data to send
headers: Optional headers dict to update
Returns:
Tuple of (updated_data, updated_headers)
"""
if headers is None:
headers = {} # pragma: no cover
if not self.client_info:
return data, headers # pragma: no cover
auth_method = self.client_info.token_endpoint_auth_method
if auth_method == "client_secret_basic" and self.client_info.client_id and self.client_info.client_secret:
# URL-encode client ID and secret per RFC 6749 Section 2.3.1
encoded_id = quote(self.client_info.client_id, safe="")
encoded_secret = quote(self.client_info.client_secret, safe="")
credentials = f"{encoded_id}:{encoded_secret}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded_credentials}"
# Don't include client_secret in body for basic auth
data = {k: v for k, v in data.items() if k != "client_secret"}
elif auth_method == "client_secret_post" and self.client_info.client_secret:
# Include client_secret in request body
data["client_secret"] = self.client_info.client_secret
# For auth_method == "none", don't add any client_secret
return data, headers
class OAuthClientProvider(httpx.Auth):
"""
OAuth2 authentication for httpx.
Handles OAuth flow with automatic client registration and token storage.
"""
requires_response_body = True
def __init__(
self,
server_url: str,
client_metadata: OAuthClientMetadata,
storage: TokenStorage,
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
timeout: float = 300.0,
client_metadata_url: str | None = None,
):
"""Initialize OAuth2 authentication.
Args:
server_url: The MCP server URL.
client_metadata: OAuth client metadata for registration.
storage: Token storage implementation.
redirect_handler: Handler for authorization redirects.
callback_handler: Handler for authorization callbacks.
timeout: Timeout for the OAuth flow.
client_metadata_url: URL-based client ID. When provided and the server
advertises client_id_metadata_document_supported=true, this URL will be
used as the client_id instead of performing dynamic client registration.
Must be a valid HTTPS URL with a non-root pathname.
Raises:
ValueError: If client_metadata_url is provided but not a valid HTTPS URL
with a non-root pathname.
"""
# Validate client_metadata_url if provided
if client_metadata_url is not None and not is_valid_client_metadata_url(client_metadata_url):
raise ValueError(
f"client_metadata_url must be a valid HTTPS URL with a non-root pathname, got: {client_metadata_url}"
)
self.context = OAuthContext(
server_url=server_url,
client_metadata=client_metadata,
storage=storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
timeout=timeout,
client_metadata_url=client_metadata_url,
)
self._initialized = False
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
"""
Handle protected resource metadata discovery response.
Per SEP-985, supports fallback when discovery fails at one URL.
Returns:
True if metadata was successfully discovered, False if we should try next URL
"""
if response.status_code == 200:
try:
content = await response.aread()
metadata = ProtectedResourceMetadata.model_validate_json(content)
self.context.protected_resource_metadata = metadata
if metadata.authorization_servers: # pragma: no branch
self.context.auth_server_url = str(metadata.authorization_servers[0])
return True
except ValidationError: # pragma: no cover
# Invalid metadata - try next URL
logger.warning(f"Invalid protected resource metadata at {response.request.url}")
return False
elif response.status_code == 404: # pragma: no cover
# Not found - try next URL in fallback chain
logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
return False
else:
# Other error - fail immediately
raise OAuthFlowError(
f"Protected Resource Metadata request failed: {response.status_code}"
) # pragma: no cover
async def _perform_authorization(self) -> httpx.Request:
"""Perform the authorization flow."""
auth_code, code_verifier = await self._perform_authorization_code_grant()
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
return token_request
async def _perform_authorization_code_grant(self) -> tuple[str, str]:
"""Perform the authorization redirect and get auth code."""
if self.context.client_metadata.redirect_uris is None:
raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover
if not self.context.redirect_handler:
raise OAuthFlowError("No redirect handler provided for authorization code grant") # pragma: no cover
if not self.context.callback_handler:
raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover
if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint:
auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) # pragma: no cover
else:
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
auth_endpoint = urljoin(auth_base_url, "/authorize")
if not self.context.client_info:
raise OAuthFlowError("No client info available for authorization") # pragma: no cover
# Generate PKCE parameters
pkce_params = PKCEParameters.generate()
state = secrets.token_urlsafe(32)
auth_params = {
"response_type": "code",
"client_id": self.context.client_info.client_id,
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
"state": state,
"code_challenge": pkce_params.code_challenge,
"code_challenge_method": "S256",
}
# Only include resource param if conditions are met
if self.context.should_include_resource_param(self.context.protocol_version):
auth_params["resource"] = self.context.get_resource_url() # RFC 8707 # pragma: no cover
if self.context.client_metadata.scope: # pragma: no branch
auth_params["scope"] = self.context.client_metadata.scope
authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
await self.context.redirect_handler(authorization_url)
# Wait for callback
auth_code, returned_state = await self.context.callback_handler()
if returned_state is None or not secrets.compare_digest(returned_state, state):
raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") # pragma: no cover
if not auth_code:
raise OAuthFlowError("No authorization code received") # pragma: no cover
# Return auth code and code verifier for token exchange
return auth_code, pkce_params.code_verifier
def _get_token_endpoint(self) -> str:
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
token_url = str(self.context.oauth_metadata.token_endpoint)
else:
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
token_url = urljoin(auth_base_url, "/token")
return token_url
async def _exchange_token_authorization_code(
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {}
) -> httpx.Request:
"""Build token exchange request for authorization_code flow."""
if self.context.client_metadata.redirect_uris is None:
raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover
if not self.context.client_info:
raise OAuthFlowError("Missing client info") # pragma: no cover
token_url = self._get_token_endpoint()
token_data = token_data or {}
token_data.update(
{
"grant_type": "authorization_code",
"code": auth_code,
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
"client_id": self.context.client_info.client_id,
"code_verifier": code_verifier,
}
)
# Only include resource param if conditions are met
if self.context.should_include_resource_param(self.context.protocol_version):
token_data["resource"] = self.context.get_resource_url() # RFC 8707
# Prepare authentication based on preferred method
headers = {"Content-Type": "application/x-www-form-urlencoded"}
token_data, headers = self.context.prepare_token_auth(token_data, headers)
return httpx.Request("POST", token_url, data=token_data, headers=headers)
async def _handle_token_response(self, response: httpx.Response) -> None:
"""Handle token exchange response."""
if response.status_code != 200:
body = await response.aread() # pragma: no cover
body_text = body.decode("utf-8") # pragma: no cover
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") # pragma: no cover
# Parse and validate response with scope validation
token_response = await handle_token_response_scopes(response)
# Store tokens in context
self.context.current_tokens = token_response
self.context.update_token_expiry(token_response)
await self.context.storage.set_tokens(token_response)
async def _refresh_token(self) -> httpx.Request:
"""Build token refresh request."""
if not self.context.current_tokens or not self.context.current_tokens.refresh_token:
raise OAuthTokenError("No refresh token available") # pragma: no cover
if not self.context.client_info or not self.context.client_info.client_id:
raise OAuthTokenError("No client info available") # pragma: no cover
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
token_url = str(self.context.oauth_metadata.token_endpoint) # pragma: no cover
else:
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
token_url = urljoin(auth_base_url, "/token")
refresh_data: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": self.context.current_tokens.refresh_token,
"client_id": self.context.client_info.client_id,
}
# Only include resource param if conditions are met
if self.context.should_include_resource_param(self.context.protocol_version):
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
# Prepare authentication based on preferred method
headers = {"Content-Type": "application/x-www-form-urlencoded"}
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)
return httpx.Request("POST", token_url, data=refresh_data, headers=headers)
async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover
"""Handle token refresh response. Returns True if successful."""
if response.status_code != 200:
logger.warning(f"Token refresh failed: {response.status_code}")
self.context.clear_tokens()
return False
try:
content = await response.aread()
token_response = OAuthToken.model_validate_json(content)
self.context.current_tokens = token_response
self.context.update_token_expiry(token_response)
await self.context.storage.set_tokens(token_response)
return True
except ValidationError:
logger.exception("Invalid refresh response")
self.context.clear_tokens()
return False
async def _initialize(self) -> None: # pragma: no cover
"""Load stored tokens and client info."""
self.context.current_tokens = await self.context.storage.get_tokens()
self.context.client_info = await self.context.storage.get_client_info()
self._initialized = True
def _add_auth_header(self, request: httpx.Request) -> None:
"""Add authorization header to request if we have valid tokens."""
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
content = await response.aread()
metadata = OAuthMetadata.model_validate_json(content)
self.context.oauth_metadata = metadata
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
async with self.context.lock:
if not self._initialized:
await self._initialize() # pragma: no cover
# Capture protocol version from request headers
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
if not self.context.is_token_valid() and self.context.can_refresh_token():
# Try to refresh token
refresh_request = await self._refresh_token() # pragma: no cover
refresh_response = yield refresh_request # pragma: no cover
if not await self._handle_refresh_response(refresh_response): # pragma: no cover
# Refresh failed, need full re-authentication
self._initialized = False
if self.context.is_token_valid():
self._add_auth_header(request)
response = yield request
if response.status_code == 401:
# Perform full OAuth flow
try:
# OAuth flow must be inline due to generator constraints
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url, self.context.server_url
)
for url in prm_discovery_urls: # pragma: no branch
discovery_request = create_oauth_metadata_request(url)
discovery_response = yield discovery_request # sending request
prm = await handle_protected_resource_response(discovery_response)
if prm:
self.context.protected_resource_metadata = prm
# todo: try all authorization_servers to find the OASM
assert (
len(prm.authorization_servers) > 0
) # this is always true as authorization_servers has a min length of 1
self.context.auth_server_url = str(prm.authorization_servers[0])
break
else:
logger.debug(f"Protected resource metadata discovery failed: {url}")
asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
self.context.auth_server_url, self.context.server_url
)
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
for url in asm_discovery_urls: # pragma: no cover
oauth_metadata_request = create_oauth_metadata_request(url)
oauth_metadata_response = yield oauth_metadata_request
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
if not ok:
break
if ok and asm:
self.context.oauth_metadata = asm
break
else:
logger.debug(f"OAuth metadata discovery failed: {url}")
# Step 3: Apply scope selection strategy
self.context.client_metadata.scope = get_client_metadata_scopes(
extract_scope_from_www_auth(response),
self.context.protected_resource_metadata,
self.context.oauth_metadata,
)
# Step 4: Register client or use URL-based client ID (CIMD)
if not self.context.client_info:
if should_use_client_metadata_url(
self.context.oauth_metadata, self.context.client_metadata_url
):
# Use URL-based client ID (CIMD)
logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}")
client_information = create_client_info_from_metadata_url(
self.context.client_metadata_url, # type: ignore[arg-type]
redirect_uris=self.context.client_metadata.redirect_uris,
)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)
else:
# Fallback to Dynamic Client Registration
registration_request = create_client_registration_request(
self.context.oauth_metadata,
self.context.client_metadata,
self.context.get_authorization_base_url(self.context.server_url),
)
registration_response = yield registration_request
client_information = await handle_registration_response(registration_response)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)
# Step 5: Perform authorization and complete token exchange
token_response = yield await self._perform_authorization()
await self._handle_token_response(token_response)
except Exception: # pragma: no cover
logger.exception("OAuth flow error")
raise
# Retry with new tokens
self._add_auth_header(request)
yield request
elif response.status_code == 403:
# Step 1: Extract error field from WWW-Authenticate header
error = extract_field_from_www_auth(response, "error")
# Step 2: Check if we need to step-up authorization
if error == "insufficient_scope": # pragma: no branch
try:
# Step 2a: Update the required scopes
self.context.client_metadata.scope = get_client_metadata_scopes(
extract_scope_from_www_auth(response), self.context.protected_resource_metadata
)
# Step 2b: Perform (re-)authorization and token exchange
token_response = yield await self._perform_authorization()
await self._handle_token_response(token_response)
except Exception: # pragma: no cover
logger.exception("OAuth flow error")
raise
# Retry with new tokens
self._add_auth_header(request)
yield request

View File

@@ -0,0 +1,336 @@
import logging
import re
from urllib.parse import urljoin, urlparse
from httpx import Request, Response
from pydantic import AnyUrl, ValidationError
from mcp.client.auth import OAuthRegistrationError, OAuthTokenError
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthToken,
ProtectedResourceMetadata,
)
from mcp.types import LATEST_PROTOCOL_VERSION
logger = logging.getLogger(__name__)
def extract_field_from_www_auth(response: Response, field_name: str) -> str | None:
"""
Extract field from WWW-Authenticate header.
Returns:
Field value if found in WWW-Authenticate header, None otherwise
"""
www_auth_header = response.headers.get("WWW-Authenticate")
if not www_auth_header:
return None
# Pattern matches: field_name="value" or field_name=value (unquoted)
pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))'
match = re.search(pattern, www_auth_header)
if match:
# Return quoted value if present, otherwise unquoted value
return match.group(1) or match.group(2)
return None
def extract_scope_from_www_auth(response: Response) -> str | None:
"""
Extract scope parameter from WWW-Authenticate header as per RFC6750.
Returns:
Scope string if found in WWW-Authenticate header, None otherwise
"""
return extract_field_from_www_auth(response, "scope")
def extract_resource_metadata_from_www_auth(response: Response) -> str | None:
"""
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
Returns:
Resource metadata URL if found in WWW-Authenticate header, None otherwise
"""
if not response or response.status_code != 401:
return None # pragma: no cover
return extract_field_from_www_auth(response, "resource_metadata")
def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]:
"""
Build ordered list of URLs to try for protected resource metadata discovery.
Per SEP-985, the client MUST:
1. Try resource_metadata from WWW-Authenticate header (if present)
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
Args:
www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header
server_url: server url
Returns:
Ordered list of URLs to try for discovery
"""
urls: list[str] = []
# Priority 1: WWW-Authenticate header with resource_metadata parameter
if www_auth_url:
urls.append(www_auth_url)
# Priority 2-3: Well-known URIs (RFC 9728)
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
# Priority 2: Path-based well-known URI (if server has a path component)
if parsed.path and parsed.path != "/":
path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}")
urls.append(path_based_url)
# Priority 3: Root-based well-known URI
root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
urls.append(root_based_url)
return urls
def get_client_metadata_scopes(
www_authenticate_scope: str | None,
protected_resource_metadata: ProtectedResourceMetadata | None,
authorization_server_metadata: OAuthMetadata | None = None,
) -> str | None:
"""Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec."""
# Per MCP spec, scope selection priority order:
# 1. Use scope from WWW-Authenticate header (if provided)
# 2. Use all scopes from PRM scopes_supported (if available)
# 3. Omit scope parameter if neither is available
if www_authenticate_scope is not None:
# Priority 1: WWW-Authenticate header scope
return www_authenticate_scope
elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None:
# Priority 2: PRM scopes_supported
return " ".join(protected_resource_metadata.scopes_supported)
elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None:
return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover
else:
# Priority 3: Omit scope parameter
return None
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
"""
Generate ordered list of (url, type) tuples for discovery attempts.
Args:
auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None
server_url: URL for the MCP server, used as a fallback if auth_server_url is None
"""
if not auth_server_url:
# Legacy path using the 2025-03-26 spec:
# link: https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization
parsed = urlparse(server_url)
return [f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server"]
urls: list[str] = []
parsed = urlparse(auth_server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
# RFC 8414: Path-aware OAuth discovery
if parsed.path and parsed.path != "/":
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
urls.append(urljoin(base_url, oauth_path))
# RFC 8414 section 5: Path-aware OIDC discovery
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
urls.append(urljoin(base_url, oidc_path))
# https://openid.net/specs/openid-connect-discovery-1_0.html
oidc_path = f"{parsed.path.rstrip('/')}/.well-known/openid-configuration"
urls.append(urljoin(base_url, oidc_path))
return urls
# OAuth root
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
# https://openid.net/specs/openid-connect-discovery-1_0.html
urls.append(urljoin(base_url, "/.well-known/openid-configuration"))
return urls
async def handle_protected_resource_response(
response: Response,
) -> ProtectedResourceMetadata | None:
"""
Handle protected resource metadata discovery response.
Per SEP-985, supports fallback when discovery fails at one URL.
Returns:
True if metadata was successfully discovered, False if we should try next URL
"""
if response.status_code == 200:
try:
content = await response.aread()
metadata = ProtectedResourceMetadata.model_validate_json(content)
return metadata
except ValidationError: # pragma: no cover
# Invalid metadata - try next URL
return None
else:
# Not found - try next URL in fallback chain
return None
async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuthMetadata | None]:
if response.status_code == 200:
try:
content = await response.aread()
asm = OAuthMetadata.model_validate_json(content)
return True, asm
except ValidationError: # pragma: no cover
return True, None
elif response.status_code < 400 or response.status_code >= 500:
return False, None # Non-4XX error, stop trying
return True, None
def create_oauth_metadata_request(url: str) -> Request:
return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
def create_client_registration_request(
auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str
) -> Request:
"""Build registration request or skip if already registered."""
if auth_server_metadata and auth_server_metadata.registration_endpoint:
registration_url = str(auth_server_metadata.registration_endpoint)
else:
registration_url = urljoin(auth_base_url, "/register")
registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"})
async def handle_registration_response(response: Response) -> OAuthClientInformationFull:
"""Handle registration response."""
if response.status_code not in (200, 201):
await response.aread()
raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")
try:
content = await response.aread()
client_info = OAuthClientInformationFull.model_validate_json(content)
return client_info
# self.context.client_info = client_info
# await self.context.storage.set_client_info(client_info)
except ValidationError as e: # pragma: no cover
raise OAuthRegistrationError(f"Invalid registration response: {e}")
def is_valid_client_metadata_url(url: str | None) -> bool:
"""Validate that a URL is suitable for use as a client_id (CIMD).
The URL must be HTTPS with a non-root pathname.
Args:
url: The URL to validate
Returns:
True if the URL is a valid HTTPS URL with a non-root pathname
"""
if not url:
return False
try:
parsed = urlparse(url)
return parsed.scheme == "https" and parsed.path not in ("", "/")
except Exception:
return False
def should_use_client_metadata_url(
oauth_metadata: OAuthMetadata | None,
client_metadata_url: str | None,
) -> bool:
"""Determine if URL-based client ID (CIMD) should be used instead of DCR.
URL-based client IDs should be used when:
1. The server advertises client_id_metadata_document_supported=true
2. The client has a valid client_metadata_url configured
Args:
oauth_metadata: OAuth authorization server metadata
client_metadata_url: URL-based client ID (already validated)
Returns:
True if CIMD should be used, False if DCR should be used
"""
if not client_metadata_url:
return False
if not oauth_metadata:
return False
return oauth_metadata.client_id_metadata_document_supported is True
def create_client_info_from_metadata_url(
client_metadata_url: str, redirect_uris: list[AnyUrl] | None = None
) -> OAuthClientInformationFull:
"""Create client information using a URL-based client ID (CIMD).
When using URL-based client IDs, the URL itself becomes the client_id
and no client_secret is used (token_endpoint_auth_method="none").
Args:
client_metadata_url: The URL to use as the client_id
redirect_uris: The redirect URIs from the client metadata (passed through for
compatibility with OAuthClientInformationFull which inherits from OAuthClientMetadata)
Returns:
OAuthClientInformationFull with the URL as client_id
"""
return OAuthClientInformationFull(
client_id=client_metadata_url,
token_endpoint_auth_method="none",
redirect_uris=redirect_uris,
)
async def handle_token_response_scopes(
response: Response,
) -> OAuthToken:
"""Parse and validate token response with optional scope validation.
Parses token response JSON. Callers should check response.status_code before calling.
Args:
response: HTTP response from token endpoint (status already checked by caller)
Returns:
Validated OAuthToken model
Raises:
OAuthTokenError: If response JSON is invalid
"""
try:
content = await response.aread()
token_response = OAuthToken.model_validate_json(content)
return token_response
except ValidationError as e: # pragma: no cover
raise OAuthTokenError(f"Invalid token response: {e}")

View File

@@ -0,0 +1,9 @@
"""
Experimental client features.
WARNING: These APIs are experimental and may change without notice.
"""
from mcp.client.experimental.tasks import ExperimentalClientFeatures
__all__ = ["ExperimentalClientFeatures"]

View File

@@ -0,0 +1,290 @@
"""
Experimental task handler protocols for server -> client requests.
This module provides Protocol types and default handlers for when servers
send task-related requests to clients (the reverse of normal client -> server flow).
WARNING: These APIs are experimental and may change without notice.
Use cases:
- Server sends task-augmented sampling/elicitation request to client
- Client creates a local task, spawns background work, returns CreateTaskResult
- Server polls client's task status via tasks/get, tasks/result, etc.
"""
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol
from pydantic import TypeAdapter
import mcp.types as types
from mcp.shared.context import RequestContext
from mcp.shared.session import RequestResponder
if TYPE_CHECKING:
from mcp.client.session import ClientSession
class GetTaskHandlerFnT(Protocol):
"""Handler for tasks/get requests from server.
WARNING: This is experimental and may change without notice.
"""
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.GetTaskRequestParams,
) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch
class GetTaskResultHandlerFnT(Protocol):
"""Handler for tasks/result requests from server.
WARNING: This is experimental and may change without notice.
"""
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.GetTaskPayloadRequestParams,
) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch
class ListTasksHandlerFnT(Protocol):
"""Handler for tasks/list requests from server.
WARNING: This is experimental and may change without notice.
"""
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.PaginatedRequestParams | None,
) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch
class CancelTaskHandlerFnT(Protocol):
"""Handler for tasks/cancel requests from server.
WARNING: This is experimental and may change without notice.
"""
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.CancelTaskRequestParams,
) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch
class TaskAugmentedSamplingFnT(Protocol):
"""Handler for task-augmented sampling/createMessage requests from server.
When server sends a CreateMessageRequest with task field, this callback
is invoked. The callback should create a task, spawn background work,
and return CreateTaskResult immediately.
WARNING: This is experimental and may change without notice.
"""
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
class TaskAugmentedElicitationFnT(Protocol):
"""Handler for task-augmented elicitation/create requests from server.
When server sends an ElicitRequest with task field, this callback
is invoked. The callback should create a task, spawn background work,
and return CreateTaskResult immediately.
WARNING: This is experimental and may change without notice.
"""
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
async def default_get_task_handler(
context: RequestContext["ClientSession", Any],
params: types.GetTaskRequestParams,
) -> types.GetTaskResult | types.ErrorData:
return types.ErrorData(
code=types.METHOD_NOT_FOUND,
message="tasks/get not supported",
)
async def default_get_task_result_handler(
context: RequestContext["ClientSession", Any],
params: types.GetTaskPayloadRequestParams,
) -> types.GetTaskPayloadResult | types.ErrorData:
return types.ErrorData(
code=types.METHOD_NOT_FOUND,
message="tasks/result not supported",
)
async def default_list_tasks_handler(
context: RequestContext["ClientSession", Any],
params: types.PaginatedRequestParams | None,
) -> types.ListTasksResult | types.ErrorData:
return types.ErrorData(
code=types.METHOD_NOT_FOUND,
message="tasks/list not supported",
)
async def default_cancel_task_handler(
context: RequestContext["ClientSession", Any],
params: types.CancelTaskRequestParams,
) -> types.CancelTaskResult | types.ErrorData:
return types.ErrorData(
code=types.METHOD_NOT_FOUND,
message="tasks/cancel not supported",
)
async def default_task_augmented_sampling(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Task-augmented sampling not supported",
)
async def default_task_augmented_elicitation(
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Task-augmented elicitation not supported",
)
@dataclass
class ExperimentalTaskHandlers:
"""Container for experimental task handlers.
Groups all task-related handlers that handle server -> client requests.
This includes both pure task requests (get, list, cancel, result) and
task-augmented request handlers (sampling, elicitation with task field).
WARNING: These APIs are experimental and may change without notice.
Example:
handlers = ExperimentalTaskHandlers(
get_task=my_get_task_handler,
list_tasks=my_list_tasks_handler,
)
session = ClientSession(..., experimental_task_handlers=handlers)
"""
# Pure task request handlers
get_task: GetTaskHandlerFnT = field(default=default_get_task_handler)
get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler)
list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler)
cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler)
# Task-augmented request handlers
augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling)
augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation)
def build_capability(self) -> types.ClientTasksCapability | None:
"""Build ClientTasksCapability from the configured handlers.
Returns a capability object that reflects which handlers are configured
(i.e., not using the default "not supported" handlers).
Returns:
ClientTasksCapability if any handlers are provided, None otherwise
"""
has_list = self.list_tasks is not default_list_tasks_handler
has_cancel = self.cancel_task is not default_cancel_task_handler
has_sampling = self.augmented_sampling is not default_task_augmented_sampling
has_elicitation = self.augmented_elicitation is not default_task_augmented_elicitation
# If no handlers are provided, return None
if not any([has_list, has_cancel, has_sampling, has_elicitation]):
return None
# Build requests capability if any request handlers are provided
requests_capability: types.ClientTasksRequestsCapability | None = None
if has_sampling or has_elicitation:
requests_capability = types.ClientTasksRequestsCapability(
sampling=types.TasksSamplingCapability(createMessage=types.TasksCreateMessageCapability())
if has_sampling
else None,
elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability())
if has_elicitation
else None,
)
return types.ClientTasksCapability(
list=types.TasksListCapability() if has_list else None,
cancel=types.TasksCancelCapability() if has_cancel else None,
requests=requests_capability,
)
@staticmethod
def handles_request(request: types.ServerRequest) -> bool:
"""Check if this handler handles the given request type."""
return isinstance(
request.root,
types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest,
)
async def handle_request(
self,
ctx: RequestContext["ClientSession", Any],
responder: RequestResponder[types.ServerRequest, types.ClientResult],
) -> None:
"""Handle a task-related request from the server.
Call handles_request() first to check if this handler can handle the request.
"""
client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
types.ClientResult | types.ErrorData
)
match responder.request.root:
case types.GetTaskRequest(params=params):
response = await self.get_task(ctx, params)
client_response = client_response_type.validate_python(response)
await responder.respond(client_response)
case types.GetTaskPayloadRequest(params=params):
response = await self.get_task_result(ctx, params)
client_response = client_response_type.validate_python(response)
await responder.respond(client_response)
case types.ListTasksRequest(params=params):
response = await self.list_tasks(ctx, params)
client_response = client_response_type.validate_python(response)
await responder.respond(client_response)
case types.CancelTaskRequest(params=params):
response = await self.cancel_task(ctx, params)
client_response = client_response_type.validate_python(response)
await responder.respond(client_response)
case _: # pragma: no cover
raise ValueError(f"Unhandled request type: {type(responder.request.root)}")
# Backwards compatibility aliases
default_task_augmented_sampling_callback = default_task_augmented_sampling
default_task_augmented_elicitation_callback = default_task_augmented_elicitation

View File

@@ -0,0 +1,224 @@
"""
Experimental client-side task support.
This module provides client methods for interacting with MCP tasks.
WARNING: These APIs are experimental and may change without notice.
Example:
# Call a tool as a task
result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"})
task_id = result.task.taskId
# Get task status
status = await session.experimental.get_task(task_id)
# Get task result when complete
if status.status == "completed":
result = await session.experimental.get_task_result(task_id, CallToolResult)
# List all tasks
tasks = await session.experimental.list_tasks()
# Cancel a task
await session.experimental.cancel_task(task_id)
"""
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, TypeVar
import mcp.types as types
from mcp.shared.experimental.tasks.polling import poll_until_terminal
if TYPE_CHECKING:
from mcp.client.session import ClientSession
ResultT = TypeVar("ResultT", bound=types.Result)
class ExperimentalClientFeatures:
"""
Experimental client features for tasks and other experimental APIs.
WARNING: These APIs are experimental and may change without notice.
Access via session.experimental:
status = await session.experimental.get_task(task_id)
"""
def __init__(self, session: "ClientSession") -> None:
self._session = session
async def call_tool_as_task(
self,
name: str,
arguments: dict[str, Any] | None = None,
*,
ttl: int = 60000,
meta: dict[str, Any] | None = None,
) -> types.CreateTaskResult:
"""Call a tool as a task, returning a CreateTaskResult for polling.
This is a convenience method for calling tools that support task execution.
The server will return a task reference instead of the immediate result,
which can then be polled via `get_task()` and retrieved via `get_task_result()`.
Args:
name: The tool name
arguments: Tool arguments
ttl: Task time-to-live in milliseconds (default: 60000 = 1 minute)
meta: Optional metadata to include in the request
Returns:
CreateTaskResult containing the task reference
Example:
# Create task
result = await session.experimental.call_tool_as_task(
"long_running_tool", {"input": "data"}
)
task_id = result.task.taskId
# Poll for completion
while True:
status = await session.experimental.get_task(task_id)
if status.status == "completed":
break
await asyncio.sleep(0.5)
# Get result
final = await session.experimental.get_task_result(task_id, CallToolResult)
"""
_meta: types.RequestParams.Meta | None = None
if meta is not None:
_meta = types.RequestParams.Meta(**meta)
return await self._session.send_request(
types.ClientRequest(
types.CallToolRequest(
params=types.CallToolRequestParams(
name=name,
arguments=arguments,
task=types.TaskMetadata(ttl=ttl),
_meta=_meta,
),
)
),
types.CreateTaskResult,
)
async def get_task(self, task_id: str) -> types.GetTaskResult:
"""
Get the current status of a task.
Args:
task_id: The task identifier
Returns:
GetTaskResult containing the task status and metadata
"""
return await self._session.send_request(
types.ClientRequest(
types.GetTaskRequest(
params=types.GetTaskRequestParams(taskId=task_id),
)
),
types.GetTaskResult,
)
async def get_task_result(
self,
task_id: str,
result_type: type[ResultT],
) -> ResultT:
"""
Get the result of a completed task.
The result type depends on the original request type:
- tools/call tasks return CallToolResult
- Other request types return their corresponding result type
Args:
task_id: The task identifier
result_type: The expected result type (e.g., CallToolResult)
Returns:
The task result, validated against result_type
"""
return await self._session.send_request(
types.ClientRequest(
types.GetTaskPayloadRequest(
params=types.GetTaskPayloadRequestParams(taskId=task_id),
)
),
result_type,
)
async def list_tasks(
self,
cursor: str | None = None,
) -> types.ListTasksResult:
"""
List all tasks.
Args:
cursor: Optional pagination cursor
Returns:
ListTasksResult containing tasks and optional next cursor
"""
params = types.PaginatedRequestParams(cursor=cursor) if cursor else None
return await self._session.send_request(
types.ClientRequest(
types.ListTasksRequest(params=params),
),
types.ListTasksResult,
)
async def cancel_task(self, task_id: str) -> types.CancelTaskResult:
"""
Cancel a running task.
Args:
task_id: The task identifier
Returns:
CancelTaskResult with the updated task state
"""
return await self._session.send_request(
types.ClientRequest(
types.CancelTaskRequest(
params=types.CancelTaskRequestParams(taskId=task_id),
)
),
types.CancelTaskResult,
)
async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
"""
Poll a task until it reaches a terminal status.
Yields GetTaskResult for each poll, allowing the caller to react to
status changes (e.g., handle input_required). Exits when task reaches
a terminal status (completed, failed, cancelled).
Respects the pollInterval hint from the server.
Args:
task_id: The task identifier
Yields:
GetTaskResult for each poll
Example:
async for status in session.experimental.poll_task(task_id):
print(f"Status: {status.status}")
if status.status == "input_required":
# Handle elicitation request via tasks/result
pass
# Task is now terminal, get the result
result = await session.experimental.get_task_result(task_id, CallToolResult)
"""
async for status in poll_until_terminal(self.get_task, task_id):
yield status

View File

@@ -0,0 +1,615 @@
import logging
from datetime import timedelta
from typing import Any, Protocol, overload
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl, TypeAdapter
from typing_extensions import deprecated
import mcp.types as types
from mcp.client.experimental import ExperimentalClientFeatures
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
logger = logging.getLogger("client")
class SamplingFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch
class ElicitationFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch
class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch
class LoggingFnT(Protocol):
async def __call__(
self,
params: types.LoggingMessageNotificationParams,
) -> None: ... # pragma: no branch
class MessageHandlerFnT(Protocol):
async def __call__(
self,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None: ... # pragma: no branch
async def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
await anyio.lowlevel.checkpoint()
async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Sampling not supported",
)
async def _default_elicitation_callback(
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
return types.ErrorData( # pragma: no cover
code=types.INVALID_REQUEST,
message="Elicitation not supported",
)
async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="List roots not supported",
)
async def _default_logging_callback(
params: types.LoggingMessageNotificationParams,
) -> None:
pass
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
class ClientSession(
BaseSession[
types.ClientRequest,
types.ClientNotification,
types.ClientResult,
types.ServerRequest,
types.ServerNotification,
]
):
def __init__(
self,
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
elicitation_callback: ElicitationFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
*,
sampling_capabilities: types.SamplingCapability | None = None,
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
self._sampling_capabilities = sampling_capabilities
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
self._server_capabilities: types.ServerCapabilities | None = None
self._experimental_features: ExperimentalClientFeatures | None = None
# Experimental: Task handlers (use defaults if not provided)
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
async def initialize(self) -> types.InitializeResult:
sampling = (
(self._sampling_capabilities or types.SamplingCapability())
if self._sampling_callback is not _default_sampling_callback
else None
)
elicitation = (
types.ElicitationCapability(
form=types.FormElicitationCapability(),
url=types.UrlElicitationCapability(),
)
if self._elicitation_callback is not _default_elicitation_callback
else None
)
roots = (
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
types.RootsCapability(listChanged=True)
if self._list_roots_callback is not _default_list_roots_callback
else None
)
result = await self.send_request(
types.ClientRequest(
types.InitializeRequest(
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=sampling,
elicitation=elicitation,
experimental=None,
roots=roots,
tasks=self._task_handlers.build_capability(),
),
clientInfo=self._client_info,
),
)
),
types.InitializeResult,
)
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
self._server_capabilities = result.capabilities
await self.send_notification(types.ClientNotification(types.InitializedNotification()))
return result
def get_server_capabilities(self) -> types.ServerCapabilities | None:
"""Return the server capabilities received during initialization.
Returns None if the session has not been initialized yet.
"""
return self._server_capabilities
@property
def experimental(self) -> ExperimentalClientFeatures:
"""Experimental APIs for tasks and other features.
WARNING: These APIs are experimental and may change without notice.
Example:
status = await session.experimental.get_task(task_id)
result = await session.experimental.get_task_result(task_id, CallToolResult)
"""
if self._experimental_features is None:
self._experimental_features = ExperimentalClientFeatures(self)
return self._experimental_features
async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
return await self.send_request(
types.ClientRequest(types.PingRequest()),
types.EmptyResult,
)
async def send_progress_notification(
self,
progress_token: str | int,
progress: float,
total: float | None = None,
message: str | None = None,
) -> None:
"""Send a progress notification."""
await self.send_notification(
types.ClientNotification(
types.ProgressNotification(
params=types.ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
message=message,
),
),
)
)
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
"""Send a logging/setLevel request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(
types.SetLevelRequest(
params=types.SetLevelRequestParams(level=level),
)
),
types.EmptyResult,
)
@overload
@deprecated("Use list_resources(params=PaginatedRequestParams(...)) instead")
async def list_resources(self, cursor: str | None) -> types.ListResourcesResult: ...
@overload
async def list_resources(self, *, params: types.PaginatedRequestParams | None) -> types.ListResourcesResult: ...
@overload
async def list_resources(self) -> types.ListResourcesResult: ...
async def list_resources(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListResourcesResult:
"""Send a resources/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
return await self.send_request(
types.ClientRequest(types.ListResourcesRequest(params=request_params)),
types.ListResourcesResult,
)
@overload
@deprecated("Use list_resource_templates(params=PaginatedRequestParams(...)) instead")
async def list_resource_templates(self, cursor: str | None) -> types.ListResourceTemplatesResult: ...
@overload
async def list_resource_templates(
self, *, params: types.PaginatedRequestParams | None
) -> types.ListResourceTemplatesResult: ...
@overload
async def list_resource_templates(self) -> types.ListResourceTemplatesResult: ...
async def list_resource_templates(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListResourceTemplatesResult:
"""Send a resources/templates/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
return await self.send_request(
types.ClientRequest(types.ListResourceTemplatesRequest(params=request_params)),
types.ListResourceTemplatesResult,
)
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request."""
return await self.send_request(
types.ClientRequest(
types.ReadResourceRequest(
params=types.ReadResourceRequestParams(uri=uri),
)
),
types.ReadResourceResult,
)
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/subscribe request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(
types.SubscribeRequest(
params=types.SubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/unsubscribe request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(
types.UnsubscribeRequest(
params=types.UnsubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
) -> types.CallToolResult:
"""Send a tools/call request with optional progress callback support."""
_meta: types.RequestParams.Meta | None = None
if meta is not None:
_meta = types.RequestParams.Meta(**meta)
result = await self.send_request(
types.ClientRequest(
types.CallToolRequest(
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta),
)
),
types.CallToolResult,
request_read_timeout_seconds=read_timeout_seconds,
progress_callback=progress_callback,
)
if not result.isError:
await self._validate_tool_result(name, result)
return result
async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:
"""Validate the structured content of a tool result against its output schema."""
if name not in self._tool_output_schemas:
# refresh output schema cache
await self.list_tools()
output_schema = None
if name in self._tool_output_schemas:
output_schema = self._tool_output_schemas.get(name)
else:
logger.warning(f"Tool {name} not listed by server, cannot validate any structured content")
if output_schema is not None:
from jsonschema import SchemaError, ValidationError, validate
if result.structuredContent is None:
raise RuntimeError(
f"Tool {name} has an output schema but did not return structured content"
) # pragma: no cover
try:
validate(result.structuredContent, output_schema)
except ValidationError as e:
raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}") # pragma: no cover
except SchemaError as e: # pragma: no cover
raise RuntimeError(f"Invalid schema for tool {name}: {e}") # pragma: no cover
@overload
@deprecated("Use list_prompts(params=PaginatedRequestParams(...)) instead")
async def list_prompts(self, cursor: str | None) -> types.ListPromptsResult: ...
@overload
async def list_prompts(self, *, params: types.PaginatedRequestParams | None) -> types.ListPromptsResult: ...
@overload
async def list_prompts(self) -> types.ListPromptsResult: ...
async def list_prompts(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListPromptsResult:
"""Send a prompts/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
return await self.send_request(
types.ClientRequest(types.ListPromptsRequest(params=request_params)),
types.ListPromptsResult,
)
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
"""Send a prompts/get request."""
return await self.send_request(
types.ClientRequest(
types.GetPromptRequest(
params=types.GetPromptRequestParams(name=name, arguments=arguments),
)
),
types.GetPromptResult,
)
async def complete(
self,
ref: types.ResourceTemplateReference | types.PromptReference,
argument: dict[str, str],
context_arguments: dict[str, str] | None = None,
) -> types.CompleteResult:
"""Send a completion/complete request."""
context = None
if context_arguments is not None:
context = types.CompletionContext(arguments=context_arguments)
return await self.send_request(
types.ClientRequest(
types.CompleteRequest(
params=types.CompleteRequestParams(
ref=ref,
argument=types.CompletionArgument(**argument),
context=context,
),
)
),
types.CompleteResult,
)
@overload
@deprecated("Use list_tools(params=PaginatedRequestParams(...)) instead")
async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ...
@overload
async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ...
@overload
async def list_tools(self) -> types.ListToolsResult: ...
async def list_tools(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListToolsResult:
"""Send a tools/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
result = await self.send_request(
types.ClientRequest(types.ListToolsRequest(params=request_params)),
types.ListToolsResult,
)
# Cache tool output schemas for future validation
# Note: don't clear the cache, as we may be using a cursor
for tool in result.tools:
self._tool_output_schemas[tool.name] = tool.outputSchema
return result
async def send_roots_list_changed(self) -> None: # pragma: no cover
"""Send a roots/list_changed notification."""
await self.send_notification(types.ClientNotification(types.RootsListChangedNotification()))
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
lifespan_context=None,
)
# Delegate to experimental task handler if applicable
if self._task_handlers.handles_request(responder.request):
with responder:
await self._task_handlers.handle_request(ctx, responder)
return None
# Core request handling
match responder.request.root:
case types.CreateMessageRequest(params=params):
with responder:
# Check if this is a task-augmented request
if params.task is not None:
response = await self._task_handlers.augmented_sampling(ctx, params, params.task)
else:
response = await self._sampling_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.ElicitRequest(params=params):
with responder:
# Check if this is a task-augmented request
if params.task is not None:
response = await self._task_handlers.augmented_elicitation(ctx, params, params.task)
else:
response = await self._elicitation_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.ListRootsRequest():
with responder:
response = await self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.PingRequest(): # pragma: no cover
with responder:
return await responder.respond(types.ClientResult(root=types.EmptyResult()))
case _: # pragma: no cover
pass # Task requests handled above by _task_handlers
return None
async def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
"""Handle incoming messages by forwarding to the message handler."""
await self._message_handler(req)
async def _received_notification(self, notification: types.ServerNotification) -> None:
"""Handle notifications from the server."""
# Process specific notification types
match notification.root:
case types.LoggingMessageNotification(params=params):
await self._logging_callback(params)
case types.ElicitCompleteNotification(params=params):
# Handle elicitation completion notification
# Clients MAY use this to retry requests or update UI
# The notification contains the elicitationId of the completed elicitation
pass
case _:
pass

View File

@@ -0,0 +1,447 @@
"""
SessionGroup concurrently manages multiple MCP session connections.
Tools, resources, and prompts are aggregated across servers. Servers may
be connected to or disconnected from at any point after initialization.
This abstractions can handle naming collisions using a custom user-provided
hook.
"""
import contextlib
import logging
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from types import TracebackType
from typing import Any, TypeAlias, overload
import anyio
import httpx
from pydantic import BaseModel
from typing_extensions import Self, deprecated
import mcp
from mcp import types
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters
from mcp.client.streamable_http import streamable_http_client
from mcp.shared._httpx_utils import create_mcp_http_client
from mcp.shared.exceptions import McpError
from mcp.shared.session import ProgressFnT
class SseServerParameters(BaseModel):
"""Parameters for intializing a sse_client."""
# The endpoint URL.
url: str
# Optional headers to include in requests.
headers: dict[str, Any] | None = None
# HTTP timeout for regular operations.
timeout: float = 5
# Timeout for SSE read operations.
sse_read_timeout: float = 60 * 5
class StreamableHttpParameters(BaseModel):
"""Parameters for intializing a streamable_http_client."""
# The endpoint URL.
url: str
# Optional headers to include in requests.
headers: dict[str, Any] | None = None
# HTTP timeout for regular operations.
timeout: timedelta = timedelta(seconds=30)
# Timeout for SSE read operations.
sse_read_timeout: timedelta = timedelta(seconds=60 * 5)
# Close the client session when the transport closes.
terminate_on_close: bool = True
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
# Use dataclass instead of pydantic BaseModel
# because pydantic BaseModel cannot handle Protocol fields.
@dataclass
class ClientSessionParameters:
"""Parameters for establishing a client session to an MCP server."""
read_timeout_seconds: timedelta | None = None
sampling_callback: SamplingFnT | None = None
elicitation_callback: ElicitationFnT | None = None
list_roots_callback: ListRootsFnT | None = None
logging_callback: LoggingFnT | None = None
message_handler: MessageHandlerFnT | None = None
client_info: types.Implementation | None = None
class ClientSessionGroup:
"""Client for managing connections to multiple MCP servers.
This class is responsible for encapsulating management of server connections.
It aggregates tools, resources, and prompts from all connected servers.
For auxiliary handlers, such as resource subscription, this is delegated to
the client and can be accessed via the session.
Example Usage:
name_fn = lambda name, server_info: f"{(server_info.name)}_{name}"
async with ClientSessionGroup(component_name_hook=name_fn) as group:
for server_param in server_params:
await group.connect_to_server(server_param)
...
"""
class _ComponentNames(BaseModel):
"""Used for reverse index to find components."""
prompts: set[str] = set()
resources: set[str] = set()
tools: set[str] = set()
# Standard MCP components.
_prompts: dict[str, types.Prompt]
_resources: dict[str, types.Resource]
_tools: dict[str, types.Tool]
# Client-server connection management.
_sessions: dict[mcp.ClientSession, _ComponentNames]
_tool_to_session: dict[str, mcp.ClientSession]
_exit_stack: contextlib.AsyncExitStack
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
# Optional fn consuming (component_name, serverInfo) for custom names.
# This is provide a means to mitigate naming conflicts across servers.
# Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}"
_ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str]
_component_name_hook: _ComponentNameHook | None
def __init__(
self,
exit_stack: contextlib.AsyncExitStack | None = None,
component_name_hook: _ComponentNameHook | None = None,
) -> None:
"""Initializes the MCP client."""
self._tools = {}
self._resources = {}
self._prompts = {}
self._sessions = {}
self._tool_to_session = {}
if exit_stack is None:
self._exit_stack = contextlib.AsyncExitStack()
self._owns_exit_stack = True
else:
self._exit_stack = exit_stack
self._owns_exit_stack = False
self._session_exit_stacks = {}
self._component_name_hook = component_name_hook
async def __aenter__(self) -> Self: # pragma: no cover
# Enter the exit stack only if we created it ourselves
if self._owns_exit_stack:
await self._exit_stack.__aenter__()
return self
async def __aexit__(
self,
_exc_type: type[BaseException] | None,
_exc_val: BaseException | None,
_exc_tb: TracebackType | None,
) -> bool | None: # pragma: no cover
"""Closes session exit stacks and main exit stack upon completion."""
# Only close the main exit stack if we created it
if self._owns_exit_stack:
await self._exit_stack.aclose()
# Concurrently close session stacks.
async with anyio.create_task_group() as tg:
for exit_stack in self._session_exit_stacks.values():
tg.start_soon(exit_stack.aclose)
@property
def sessions(self) -> list[mcp.ClientSession]:
"""Returns the list of sessions being managed."""
return list(self._sessions.keys()) # pragma: no cover
@property
def prompts(self) -> dict[str, types.Prompt]:
"""Returns the prompts as a dictionary of names to prompts."""
return self._prompts
@property
def resources(self) -> dict[str, types.Resource]:
"""Returns the resources as a dictionary of names to resources."""
return self._resources
@property
def tools(self) -> dict[str, types.Tool]:
"""Returns the tools as a dictionary of names to tools."""
return self._tools
@overload
async def call_tool(
self,
name: str,
arguments: dict[str, Any],
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
) -> types.CallToolResult: ...
@overload
@deprecated("The 'args' parameter is deprecated. Use 'arguments' instead.")
async def call_tool(
self,
name: str,
*,
args: dict[str, Any],
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
meta: dict[str, Any] | None = None,
) -> types.CallToolResult: ...
async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
args: dict[str, Any] | None = None,
) -> types.CallToolResult:
"""Executes a tool given its name and arguments."""
session = self._tool_to_session[name]
session_tool_name = self.tools[name].name
return await session.call_tool(
session_tool_name,
arguments if args is None else args,
read_timeout_seconds=read_timeout_seconds,
progress_callback=progress_callback,
meta=meta,
)
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
"""Disconnects from a single MCP server."""
session_known_for_components = session in self._sessions
session_known_for_stack = session in self._session_exit_stacks
if not session_known_for_components and not session_known_for_stack:
raise McpError(
types.ErrorData(
code=types.INVALID_PARAMS,
message="Provided session is not managed or already disconnected.",
)
)
if session_known_for_components: # pragma: no cover
component_names = self._sessions.pop(session) # Pop from _sessions tracking
# Remove prompts associated with the session.
for name in component_names.prompts:
if name in self._prompts:
del self._prompts[name]
# Remove resources associated with the session.
for name in component_names.resources:
if name in self._resources:
del self._resources[name]
# Remove tools associated with the session.
for name in component_names.tools:
if name in self._tools:
del self._tools[name]
if name in self._tool_to_session:
del self._tool_to_session[name]
# Clean up the session's resources via its dedicated exit stack
if session_known_for_stack:
session_stack_to_close = self._session_exit_stacks.pop(session) # pragma: no cover
await session_stack_to_close.aclose() # pragma: no cover
async def connect_with_session(
self, server_info: types.Implementation, session: mcp.ClientSession
) -> mcp.ClientSession:
"""Connects to a single MCP server."""
await self._aggregate_components(server_info, session)
return session
async def connect_to_server(
self,
server_params: ServerParameters,
session_params: ClientSessionParameters | None = None,
) -> mcp.ClientSession:
"""Connects to a single MCP server."""
server_info, session = await self._establish_session(server_params, session_params or ClientSessionParameters())
return await self.connect_with_session(server_info, session)
async def _establish_session(
self,
server_params: ServerParameters,
session_params: ClientSessionParameters,
) -> tuple[types.Implementation, mcp.ClientSession]:
"""Establish a client session to an MCP server."""
session_stack = contextlib.AsyncExitStack()
try:
# Create read and write streams that facilitate io with the server.
if isinstance(server_params, StdioServerParameters):
client = mcp.stdio_client(server_params)
read, write = await session_stack.enter_async_context(client)
elif isinstance(server_params, SseServerParameters):
client = sse_client(
url=server_params.url,
headers=server_params.headers,
timeout=server_params.timeout,
sse_read_timeout=server_params.sse_read_timeout,
)
read, write = await session_stack.enter_async_context(client)
else:
httpx_client = create_mcp_http_client(
headers=server_params.headers,
timeout=httpx.Timeout(
server_params.timeout.total_seconds(),
read=server_params.sse_read_timeout.total_seconds(),
),
)
await session_stack.enter_async_context(httpx_client)
client = streamable_http_client(
url=server_params.url,
http_client=httpx_client,
terminate_on_close=server_params.terminate_on_close,
)
read, write, _ = await session_stack.enter_async_context(client)
session = await session_stack.enter_async_context(
mcp.ClientSession(
read,
write,
read_timeout_seconds=session_params.read_timeout_seconds,
sampling_callback=session_params.sampling_callback,
elicitation_callback=session_params.elicitation_callback,
list_roots_callback=session_params.list_roots_callback,
logging_callback=session_params.logging_callback,
message_handler=session_params.message_handler,
client_info=session_params.client_info,
)
)
result = await session.initialize()
# Session successfully initialized.
# Store its stack and register the stack with the main group stack.
self._session_exit_stacks[session] = session_stack
# session_stack itself becomes a resource managed by the
# main _exit_stack.
await self._exit_stack.enter_async_context(session_stack)
return result.serverInfo, session
except Exception: # pragma: no cover
# If anything during this setup fails, ensure the session-specific
# stack is closed.
await session_stack.aclose()
raise
async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None:
"""Aggregates prompts, resources, and tools from a given session."""
# Create a reverse index so we can find all prompts, resources, and
# tools belonging to this session. Used for removing components from
# the session group via self.disconnect_from_server.
component_names = self._ComponentNames()
# Temporary components dicts. We do not want to modify the aggregate
# lists in case of an intermediate failure.
prompts_temp: dict[str, types.Prompt] = {}
resources_temp: dict[str, types.Resource] = {}
tools_temp: dict[str, types.Tool] = {}
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
# Query the server for its prompts and aggregate to list.
try:
prompts = (await session.list_prompts()).prompts
for prompt in prompts:
name = self._component_name(prompt.name, server_info)
prompts_temp[name] = prompt
component_names.prompts.add(name)
except McpError as err: # pragma: no cover
logging.warning(f"Could not fetch prompts: {err}")
# Query the server for its resources and aggregate to list.
try:
resources = (await session.list_resources()).resources
for resource in resources:
name = self._component_name(resource.name, server_info)
resources_temp[name] = resource
component_names.resources.add(name)
except McpError as err: # pragma: no cover
logging.warning(f"Could not fetch resources: {err}")
# Query the server for its tools and aggregate to list.
try:
tools = (await session.list_tools()).tools
for tool in tools:
name = self._component_name(tool.name, server_info)
tools_temp[name] = tool
tool_to_session_temp[name] = session
component_names.tools.add(name)
except McpError as err: # pragma: no cover
logging.warning(f"Could not fetch tools: {err}")
# Clean up exit stack for session if we couldn't retrieve anything
# from the server.
if not any((prompts_temp, resources_temp, tools_temp)):
del self._session_exit_stacks[session] # pragma: no cover
# Check for duplicates.
matching_prompts = prompts_temp.keys() & self._prompts.keys()
if matching_prompts:
raise McpError( # pragma: no cover
types.ErrorData(
code=types.INVALID_PARAMS,
message=f"{matching_prompts} already exist in group prompts.",
)
)
matching_resources = resources_temp.keys() & self._resources.keys()
if matching_resources:
raise McpError( # pragma: no cover
types.ErrorData(
code=types.INVALID_PARAMS,
message=f"{matching_resources} already exist in group resources.",
)
)
matching_tools = tools_temp.keys() & self._tools.keys()
if matching_tools:
raise McpError(
types.ErrorData(
code=types.INVALID_PARAMS,
message=f"{matching_tools} already exist in group tools.",
)
)
# Aggregate components.
self._sessions[session] = component_names
self._prompts.update(prompts_temp)
self._resources.update(resources_temp)
self._tools.update(tools_temp)
self._tool_to_session.update(tool_to_session_temp)
def _component_name(self, name: str, server_info: types.Implementation) -> str:
if self._component_name_hook:
return self._component_name_hook(name, server_info)
return name

View File

@@ -0,0 +1,164 @@
import logging
from collections.abc import Callable
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import parse_qs, urljoin, urlparse
import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
from httpx_sse._exceptions import SSEError
import mcp.types as types
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)
def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
query_params = parse_qs(urlparse(endpoint_url).query)
return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]
@asynccontextmanager
async def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
on_session_created: Callable[[str], None] | None = None,
):
"""
Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
auth: Optional HTTPX authentication handler.
on_session_created: Optional callback invoked with the session ID when received.
"""
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx_client_factory(
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
) as client:
async with aconnect_sse(
client,
"GET",
url,
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
async def sse_reader(
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
):
try:
async for sse in event_source.aiter_sse(): # pragma: no branch
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.debug(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if ( # pragma: no cover
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = ( # pragma: no cover
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg) # pragma: no cover
raise ValueError(error_msg) # pragma: no cover
if on_session_created:
session_id = _extract_session_id_from_endpoint(endpoint_url)
if session_id:
on_session_created(session_id)
task_status.started(endpoint_url)
case "message":
# Skip empty data (keep-alive pings)
if not sse.data:
continue
try:
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
)
logger.debug(f"Received server message: {message}")
except Exception as exc: # pragma: no cover
logger.exception("Error parsing server message") # pragma: no cover
await read_stream_writer.send(exc) # pragma: no cover
continue # pragma: no cover
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
case _: # pragma: no cover
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
except SSEError as sse_exc: # pragma: no cover
logger.exception("Encountered SSE exception") # pragma: no cover
raise sse_exc # pragma: no cover
except Exception as exc: # pragma: no cover
logger.exception("Error in sse_reader") # pragma: no cover
await read_stream_writer.send(exc) # pragma: no cover
finally:
await read_stream_writer.aclose()
async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
logger.debug(f"Sending client message: {session_message}")
response = await client.post(
endpoint_url,
json=session_message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
except Exception: # pragma: no cover
logger.exception("Error in post_writer") # pragma: no cover
finally:
await write_stream.aclose()
endpoint_url = await tg.start(sse_reader)
logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
tg.start_soon(post_writer, endpoint_url)
try:
yield read_stream, write_stream
finally:
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()

View File

@@ -0,0 +1,278 @@
import logging
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Literal, TextIO
import anyio
import anyio.lowlevel
from anyio.abc import Process
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from anyio.streams.text import TextReceiveStream
from pydantic import BaseModel, Field
import mcp.types as types
from mcp.os.posix.utilities import terminate_posix_process_tree
from mcp.os.win32.utilities import (
FallbackProcess,
create_windows_process,
get_windows_executable_command,
terminate_windows_process_tree,
)
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
# Environment variables to inherit by default
DEFAULT_INHERITED_ENV_VARS = (
[
"APPDATA",
"HOMEDRIVE",
"HOMEPATH",
"LOCALAPPDATA",
"PATH",
"PATHEXT",
"PROCESSOR_ARCHITECTURE",
"SYSTEMDRIVE",
"SYSTEMROOT",
"TEMP",
"USERNAME",
"USERPROFILE",
]
if sys.platform == "win32"
else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
)
# Timeout for process termination before falling back to force kill
PROCESS_TERMINATION_TIMEOUT = 2.0
def get_default_environment() -> dict[str, str]:
"""
Returns a default environment object including only environment variables deemed
safe to inherit.
"""
env: dict[str, str] = {}
for key in DEFAULT_INHERITED_ENV_VARS:
value = os.environ.get(key)
if value is None:
continue # pragma: no cover
if value.startswith("()"): # pragma: no cover
# Skip functions, which are a security risk
continue # pragma: no cover
env[key] = value
return env
class StdioServerParameters(BaseModel):
command: str
"""The executable to run to start the server."""
args: list[str] = Field(default_factory=list)
"""Command line arguments to pass to the executable."""
env: dict[str, str] | None = None
"""
The environment to use when spawning the process.
If not specified, the result of get_default_environment() will be used.
"""
cwd: str | Path | None = None
"""The working directory to use when spawning the process."""
encoding: str = "utf-8"
"""
The text encoding used when sending/receiving messages to the server
defaults to utf-8
"""
encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict"
"""
The text encoding error handler.
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
explanations of possible values
"""
@asynccontextmanager
async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr):
"""
Client transport for stdio: this will connect to a server by spawning a
process and communicating with it over stdin/stdout.
"""
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
try:
command = _get_executable_command(server.command)
# Open process with stderr piped for capture
process = await _create_platform_compatible_process(
command=command,
args=server.args,
env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
errlog=errlog,
cwd=server.cwd,
)
except OSError:
# Clean up streams if process creation fails
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
raise
async def stdout_reader():
assert process.stdout, "Opened process is missing stdout"
try:
async with read_stream_writer:
buffer = ""
async for chunk in TextReceiveStream(
process.stdout,
encoding=server.encoding,
errors=server.encoding_error_handler,
):
lines = (buffer + chunk).split("\n")
buffer = lines.pop()
for line in lines:
try:
message = types.JSONRPCMessage.model_validate_json(line)
except Exception as exc: # pragma: no cover
logger.exception("Failed to parse JSONRPC message from server")
await read_stream_writer.send(exc)
continue
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except anyio.ClosedResourceError: # pragma: no cover
await anyio.lowlevel.checkpoint()
async def stdin_writer():
assert process.stdin, "Opened process is missing stdin"
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
await process.stdin.send(
(json + "\n").encode(
encoding=server.encoding,
errors=server.encoding_error_handler,
)
)
except anyio.ClosedResourceError: # pragma: no cover
await anyio.lowlevel.checkpoint()
async with (
anyio.create_task_group() as tg,
process,
):
tg.start_soon(stdout_reader)
tg.start_soon(stdin_writer)
try:
yield read_stream, write_stream
finally:
# MCP spec: stdio shutdown sequence
# 1. Close input stream to server
# 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time
# 3. Send SIGKILL if still not exited
if process.stdin: # pragma: no branch
try:
await process.stdin.aclose()
except Exception: # pragma: no cover
# stdin might already be closed, which is fine
pass
try:
# Give the process time to exit gracefully after stdin closes
with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT):
await process.wait()
except TimeoutError:
# Process didn't exit from stdin closure, use platform-specific termination
# which handles SIGTERM -> SIGKILL escalation
await _terminate_process_tree(process)
except ProcessLookupError: # pragma: no cover
# Process already exited, which is fine
pass
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
def _get_executable_command(command: str) -> str:
"""
Get the correct executable command normalized for the current platform.
Args:
command: Base command (e.g., 'uvx', 'npx')
Returns:
str: Platform-appropriate command
"""
if sys.platform == "win32": # pragma: no cover
return get_windows_executable_command(command)
else:
return command # pragma: no cover
async def _create_platform_compatible_process(
command: str,
args: list[str],
env: dict[str, str] | None = None,
errlog: TextIO = sys.stderr,
cwd: Path | str | None = None,
):
"""
Creates a subprocess in a platform-compatible way.
Unix: Creates process in a new session/process group for killpg support
Windows: Creates process in a Job Object for reliable child termination
"""
if sys.platform == "win32": # pragma: no cover
process = await create_windows_process(command, args, env, errlog, cwd)
else:
process = await anyio.open_process(
[command, *args],
env=env,
stderr=errlog,
cwd=cwd,
start_new_session=True,
) # pragma: no cover
return process
async def _terminate_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None:
"""
Terminate a process and all its children using platform-specific methods.
Unix: Uses os.killpg() for atomic process group termination
Windows: Uses Job Objects via pywin32 for reliable child process cleanup
Args:
process: The process to terminate
timeout_seconds: Timeout in seconds before force killing (default: 2.0)
"""
if sys.platform == "win32": # pragma: no cover
await terminate_windows_process_tree(process, timeout_seconds)
else: # pragma: no cover
# FallbackProcess should only be used for Windows compatibility
assert isinstance(process, Process)
await terminate_posix_process_tree(process, timeout_seconds)

View File

@@ -0,0 +1,722 @@
"""
StreamableHTTP Client Transport Module
This module implements the StreamableHTTP transport for MCP clients,
providing support for HTTP POST requests with optional SSE streaming responses
and session management.
"""
import contextlib
import logging
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, overload
from warnings import warn
import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from typing_extensions import deprecated
from mcp.shared._httpx_utils import (
McpHttpClientFactory,
create_mcp_http_client,
)
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
ErrorData,
InitializeResult,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
RequestId,
)
logger = logging.getLogger(__name__)
SessionMessageOrError = SessionMessage | Exception
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
StreamReader = MemoryObjectReceiveStream[SessionMessage]
GetSessionIdCallback = Callable[[], str | None]
MCP_SESSION_ID = "mcp-session-id"
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
LAST_EVENT_ID = "last-event-id"
# Reconnection defaults
DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry
MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up
CONTENT_TYPE = "content-type"
ACCEPT = "accept"
JSON = "application/json"
SSE = "text/event-stream"
# Sentinel value for detecting unset optional parameters
_UNSET = object()
class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""
class ResumptionError(StreamableHTTPError):
"""Raised when resumption request is invalid."""
@dataclass
class RequestContext:
"""Context for a request operation."""
client: httpx.AsyncClient
session_id: str | None
session_message: SessionMessage
metadata: ClientMessageMetadata | None
read_stream_writer: StreamWriter
headers: dict[str, str] | None = None # Deprecated - no longer used
sse_read_timeout: float | None = None # Deprecated - no longer used
class StreamableHTTPTransport:
"""StreamableHTTP client transport implementation."""
@overload
def __init__(self, url: str) -> None: ...
@overload
@deprecated(
"Parameters headers, timeout, sse_read_timeout, and auth are deprecated. "
"Configure these on the httpx.AsyncClient instead."
)
def __init__(
self,
url: str,
headers: dict[str, str] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
auth: httpx.Auth | None = None,
) -> None: ...
def __init__(
self,
url: str,
headers: Any = _UNSET,
timeout: Any = _UNSET,
sse_read_timeout: Any = _UNSET,
auth: Any = _UNSET,
) -> None:
"""Initialize the StreamableHTTP transport.
Args:
url: The endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
auth: Optional HTTPX authentication handler.
"""
# Check for deprecated parameters and issue runtime warning
deprecated_params: list[str] = []
if headers is not _UNSET:
deprecated_params.append("headers")
if timeout is not _UNSET:
deprecated_params.append("timeout")
if sse_read_timeout is not _UNSET:
deprecated_params.append("sse_read_timeout")
if auth is not _UNSET:
deprecated_params.append("auth")
if deprecated_params:
warn(
f"Parameters {', '.join(deprecated_params)} are deprecated and will be ignored. "
"Configure these on the httpx.AsyncClient instead.",
DeprecationWarning,
stacklevel=2,
)
self.url = url
self.session_id = None
self.protocol_version = None
def _prepare_headers(self) -> dict[str, str]:
"""Build MCP-specific request headers.
These headers will be merged with the httpx.AsyncClient's default headers,
with these MCP-specific headers taking precedence.
"""
headers: dict[str, str] = {}
# Add MCP protocol headers
headers[ACCEPT] = f"{JSON}, {SSE}"
headers[CONTENT_TYPE] = JSON
# Add session headers if available
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
if self.protocol_version:
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
return headers
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request."""
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialized notification."""
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
def _maybe_extract_session_id_from_response(
self,
response: httpx.Response,
) -> None:
"""Extract and store session ID from response headers."""
new_session_id = response.headers.get(MCP_SESSION_ID)
if new_session_id:
self.session_id = new_session_id
logger.info(f"Received session ID: {self.session_id}")
def _maybe_extract_protocol_version_from_message(
self,
message: JSONRPCMessage,
) -> None:
"""Extract protocol version from initialization response message."""
if isinstance(message.root, JSONRPCResponse) and message.root.result: # pragma: no branch
try:
# Parse the result as InitializeResult for type safety
init_result = InitializeResult.model_validate(message.root.result)
self.protocol_version = str(init_result.protocolVersion)
logger.info(f"Negotiated protocol version: {self.protocol_version}")
except Exception as exc: # pragma: no cover
logger.warning(
f"Failed to parse initialization response as InitializeResult: {exc}"
) # pragma: no cover
logger.warning(f"Raw result: {message.root.result}")
async def _handle_sse_event(
self,
sse: ServerSentEvent,
read_stream_writer: StreamWriter,
original_request_id: RequestId | None = None,
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
is_initialization: bool = False,
) -> bool:
"""Handle an SSE event, returning True if the response is complete."""
if sse.event == "message":
# Handle priming events (empty data with ID) for resumability
if not sse.data:
# Call resumption callback for priming events that have an ID
if sse.id and resumption_callback:
await resumption_callback(sse.id)
return False
try:
message = JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"SSE message: {message}")
# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)
# If this is a response and we have original_request_id, replace it
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
message.root.id = original_request_id
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
# Call resumption token callback if we have an ID
if sse.id and resumption_callback:
await resumption_callback(sse.id)
# If this is a response or error return True indicating completion
# Otherwise, return False to continue listening
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
except Exception as exc: # pragma: no cover
logger.exception("Error parsing SSE message")
await read_stream_writer.send(exc)
return False
else: # pragma: no cover
logger.warning(f"Unknown SSE event: {sse.event}")
return False
async def handle_get_stream(
self,
client: httpx.AsyncClient,
read_stream_writer: StreamWriter,
) -> None:
"""Handle GET stream for server-initiated messages with auto-reconnect."""
last_event_id: str | None = None
retry_interval_ms: int | None = None
attempt: int = 0
while attempt < MAX_RECONNECTION_ATTEMPTS: # pragma: no branch
try:
if not self.session_id:
return
headers = self._prepare_headers()
if last_event_id:
headers[LAST_EVENT_ID] = last_event_id # pragma: no cover
async with aconnect_sse(
client,
"GET",
self.url,
headers=headers,
) as event_source:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")
async for sse in event_source.aiter_sse():
# Track last event ID for reconnection
if sse.id:
last_event_id = sse.id # pragma: no cover
# Track retry interval from server
if sse.retry is not None:
retry_interval_ms = sse.retry # pragma: no cover
await self._handle_sse_event(sse, read_stream_writer)
# Stream ended normally (server closed) - reset attempt counter
attempt = 0
except Exception as exc: # pragma: no cover
logger.debug(f"GET stream error: {exc}")
attempt += 1
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
logger.debug(f"GET stream max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
return
# Wait before reconnecting
delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
logger.info(f"GET stream disconnected, reconnecting in {delay_ms}ms...")
await anyio.sleep(delay_ms / 1000.0)
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
"""Handle a resumption request using GET with SSE."""
headers = self._prepare_headers()
if ctx.metadata and ctx.metadata.resumption_token:
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
else:
raise ResumptionError("Resumption request requires a resumption token") # pragma: no cover
# Extract original request ID to map responses
original_request_id = None
if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch
original_request_id = ctx.session_message.message.root.id
async with aconnect_sse(
ctx.client,
"GET",
self.url,
headers=headers,
) as event_source:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")
async for sse in event_source.aiter_sse(): # pragma: no branch
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
await event_source.response.aclose()
break
async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._prepare_headers()
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)
async with ctx.client.stream(
"POST",
self.url,
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
headers=headers,
) as response:
if response.status_code == 202:
logger.debug("Received 202 Accepted")
return
if response.status_code == 404: # pragma: no branch
if isinstance(message.root, JSONRPCRequest):
await self._send_session_terminated_error( # pragma: no cover
ctx.read_stream_writer, # pragma: no cover
message.root.id, # pragma: no cover
) # pragma: no cover
return # pragma: no cover
response.raise_for_status()
if is_initialization:
self._maybe_extract_session_id_from_response(response)
# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
if isinstance(message.root, JSONRPCRequest):
content_type = response.headers.get(CONTENT_TYPE, "").lower()
if content_type.startswith(JSON):
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
elif content_type.startswith(SSE):
await self._handle_sse_response(response, ctx, is_initialization)
else:
await self._handle_unexpected_content_type( # pragma: no cover
content_type, # pragma: no cover
ctx.read_stream_writer, # pragma: no cover
) # pragma: no cover
async def _handle_json_response(
self,
response: httpx.Response,
read_stream_writer: StreamWriter,
is_initialization: bool = False,
) -> None:
"""Handle JSON response from the server."""
try:
content = await response.aread()
message = JSONRPCMessage.model_validate_json(content)
# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except Exception as exc: # pragma: no cover
logger.exception("Error parsing JSON response")
await read_stream_writer.send(exc)
async def _handle_sse_response(
self,
response: httpx.Response,
ctx: RequestContext,
is_initialization: bool = False,
) -> None:
"""Handle SSE response from the server."""
last_event_id: str | None = None
retry_interval_ms: int | None = None
try:
event_source = EventSource(response)
async for sse in event_source.aiter_sse(): # pragma: no branch
# Track last event ID for potential reconnection
if sse.id:
last_event_id = sse.id
# Track retry interval from server
if sse.retry is not None:
retry_interval_ms = sse.retry
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
is_initialization=is_initialization,
)
# If the SSE event indicates completion, like returning respose/error
# break the loop
if is_complete:
await response.aclose()
return # Normal completion, no reconnect needed
except Exception as e: # pragma: no cover
logger.debug(f"SSE stream ended: {e}")
# Stream ended without response - reconnect if we received an event with ID
if last_event_id is not None: # pragma: no branch
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)
async def _handle_reconnection(
self,
ctx: RequestContext,
last_event_id: str,
retry_interval_ms: int | None = None,
attempt: int = 0,
) -> None:
"""Reconnect with Last-Event-ID to resume stream after server disconnect."""
# Bail if max retries exceeded
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
return
# Always wait - use server value or default
delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
await anyio.sleep(delay_ms / 1000.0)
headers = self._prepare_headers()
headers[LAST_EVENT_ID] = last_event_id
# Extract original request ID to map responses
original_request_id = None
if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch
original_request_id = ctx.session_message.message.root.id
try:
async with aconnect_sse(
ctx.client,
"GET",
self.url,
headers=headers,
) as event_source:
event_source.response.raise_for_status()
logger.info("Reconnected to SSE stream")
# Track for potential further reconnection
reconnect_last_event_id: str = last_event_id
reconnect_retry_ms = retry_interval_ms
async for sse in event_source.aiter_sse():
if sse.id: # pragma: no branch
reconnect_last_event_id = sse.id
if sse.retry is not None:
reconnect_retry_ms = sse.retry
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
await event_source.response.aclose()
return
# Stream ended again without response - reconnect again (reset attempt counter)
logger.info("SSE stream disconnected, reconnecting...")
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
except Exception as e: # pragma: no cover
logger.debug(f"Reconnection failed: {e}")
# Try to reconnect again if we still have an event ID
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)
async def _handle_unexpected_content_type(
self,
content_type: str,
read_stream_writer: StreamWriter,
) -> None: # pragma: no cover
"""Handle unexpected content type in response."""
error_msg = f"Unexpected content type: {content_type}" # pragma: no cover
logger.error(error_msg) # pragma: no cover
await read_stream_writer.send(ValueError(error_msg)) # pragma: no cover
async def _send_session_terminated_error(
self,
read_stream_writer: StreamWriter,
request_id: RequestId,
) -> None:
"""Send a session terminated error response."""
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=ErrorData(code=32600, message="Session terminated"),
)
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
await read_stream_writer.send(session_message)
async def post_writer(
self,
client: httpx.AsyncClient,
write_stream_reader: StreamReader,
read_stream_writer: StreamWriter,
write_stream: MemoryObjectSendStream[SessionMessage],
start_get_stream: Callable[[], None],
tg: TaskGroup,
) -> None:
"""Handle writing requests to the server."""
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
message = session_message.message
metadata = (
session_message.metadata
if isinstance(session_message.metadata, ClientMessageMetadata)
else None
)
# Check if this is a resumption request
is_resumption = bool(metadata and metadata.resumption_token)
logger.debug(f"Sending client message: {message}")
# Handle initialized notification
if self._is_initialized_notification(message):
start_get_stream()
ctx = RequestContext(
client=client,
session_id=self.session_id,
session_message=session_message,
metadata=metadata,
read_stream_writer=read_stream_writer,
)
async def handle_request_async():
if is_resumption:
await self._handle_resumption_request(ctx)
else:
await self._handle_post_request(ctx)
# If this is a request, start a new task to handle it
if isinstance(message.root, JSONRPCRequest):
tg.start_soon(handle_request_async)
else:
await handle_request_async()
except Exception:
logger.exception("Error in post_writer") # pragma: no cover
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
async def terminate_session(self, client: httpx.AsyncClient) -> None: # pragma: no cover
"""Terminate the session by sending a DELETE request."""
if not self.session_id:
return
try:
headers = self._prepare_headers()
response = await client.delete(self.url, headers=headers)
if response.status_code == 405:
logger.debug("Server does not allow session termination")
elif response.status_code not in (200, 204):
logger.warning(f"Session termination failed: {response.status_code}")
except Exception as exc:
logger.warning(f"Session termination failed: {exc}")
def get_session_id(self) -> str | None:
"""Get the current session ID."""
return self.session_id
@asynccontextmanager
async def streamable_http_client(
url: str,
*,
http_client: httpx.AsyncClient | None = None,
terminate_on_close: bool = True,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback,
],
None,
]:
"""
Client transport for StreamableHTTP.
Args:
url: The MCP server endpoint URL.
http_client: Optional pre-configured httpx.AsyncClient. If None, a default
client with recommended MCP timeouts will be created. To configure headers,
authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here.
terminate_on_close: If True, send a DELETE request to terminate the session
when the context exits.
Yields:
Tuple containing:
- read_stream: Stream for reading messages from the server
- write_stream: Stream for sending messages to the server
- get_session_id_callback: Function to retrieve the current session ID
Example:
See examples/snippets/clients/ for usage patterns.
"""
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
# Determine if we need to create and manage the client
client_provided = http_client is not None
client = http_client
if client is None:
# Create default client with recommended MCP timeouts
client = create_mcp_http_client()
transport = StreamableHTTPTransport(url)
async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
async with contextlib.AsyncExitStack() as stack:
# Only manage client lifecycle if we created it
if not client_provided:
await stack.enter_async_context(client)
def start_get_stream() -> None:
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
tg.start_soon(
transport.post_writer,
client,
write_stream_reader,
read_stream_writer,
write_stream,
start_get_stream,
tg,
)
try:
yield (
read_stream,
write_stream,
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
@asynccontextmanager
@deprecated("Use `streamable_http_client` instead.")
async def streamablehttp_client(
url: str,
headers: dict[str, str] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
terminate_on_close: bool = True,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback,
],
None,
]:
# Convert timeout parameters
timeout_seconds = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
sse_read_timeout_seconds = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)
# Create httpx client using the factory with old-style parameters
client = httpx_client_factory(
headers=headers,
timeout=httpx.Timeout(timeout_seconds, read=sse_read_timeout_seconds),
auth=auth,
)
# Manage client lifecycle since we created it
async with client:
async with streamable_http_client(
url,
http_client=client,
terminate_on_close=terminate_on_close,
) as streams:
yield streams

View File

@@ -0,0 +1,86 @@
import json
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError
from websockets.asyncio.client import connect as ws_connect
from websockets.typing import Subprotocol
import mcp.types as types
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
@asynccontextmanager
async def websocket_client(
url: str,
) -> AsyncGenerator[
tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]],
None,
]:
"""
WebSocket client transport for MCP, symmetrical to the server version.
Connects to 'url' using the 'mcp' subprotocol, then yields:
(read_stream, write_stream)
- read_stream: As you read from this stream, you'll receive either valid
JSONRPCMessage objects or Exception objects (when validation fails).
- write_stream: Write JSONRPCMessage objects to this stream to send them
over the WebSocket to the server.
"""
# Create two in-memory streams:
# - One for incoming messages (read_stream, written by ws_reader)
# - One for outgoing messages (write_stream, read by ws_writer)
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
# Connect using websockets, requesting the "mcp" subprotocol
async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:
async def ws_reader():
"""
Reads text messages from the WebSocket, parses them as JSON-RPC messages,
and sends them into read_stream_writer.
"""
async with read_stream_writer:
async for raw_text in ws:
try:
message = types.JSONRPCMessage.model_validate_json(raw_text)
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except ValidationError as exc: # pragma: no cover
# If JSON parse or model validation fails, send the exception
await read_stream_writer.send(exc)
async def ws_writer():
"""
Reads JSON-RPC messages from write_stream_reader and
sends them to the server.
"""
async with write_stream_reader:
async for session_message in write_stream_reader:
# Convert to a dict, then to JSON
msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True)
await ws.send(json.dumps(msg_dict))
async with anyio.create_task_group() as tg:
# Start reader and writer tasks
tg.start_soon(ws_reader)
tg.start_soon(ws_writer)
# Yield the receive/send streams
yield (read_stream, write_stream)
# Once the caller's 'async with' block exits, we shut down
tg.cancel_scope.cancel()