Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for active project before fetching all sessions - When active project exists, use project.sessions instead of fetching from API - Added detailed console logging to debug session filtering - This prevents ALL sessions from appearing in every project's sidebar Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
85
.venv/lib/python3.11/site-packages/mcp/client/__main__.py
Normal file
85
.venv/lib/python3.11/site-packages/mcp/client/__main__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.session import RequestResponder
|
||||
|
||||
# Honor explicit -W / PYTHONWARNINGS options; otherwise silence all warnings
# for cleaner CLI output (sys.warnoptions is empty when none were given).
if not sys.warnoptions:
    import warnings

    warnings.simplefilter("ignore")

# Module-level logging setup for the client CLI.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("client")
|
||||
|
||||
|
||||
async def message_handler(
    message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
    """Log every message arriving from the server; exceptions go to the error log."""
    if not isinstance(message, Exception):
        logger.info("Received message from server: %s", message)
        return
    logger.error("Error: %s", message)
|
||||
|
||||
|
||||
async def run_session(
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
    write_stream: MemoryObjectSendStream[SessionMessage],
    client_info: types.Implementation | None = None,
):
    """Open a ClientSession over the given streams and drive initialization."""
    session_cm = ClientSession(
        read_stream,
        write_stream,
        message_handler=message_handler,
        client_info=client_info,
    )
    async with session_cm as session:
        logger.info("Initializing session")
        await session.initialize()
        logger.info("Initialized")
|
||||
|
||||
|
||||
async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]):
    """Connect to an MCP server: SSE for http(s) URLs, stdio for local commands."""
    if urlparse(command_or_url).scheme in ("http", "https"):
        # HTTP(S) endpoints are reached via the SSE transport.
        async with sse_client(command_or_url) as streams:
            await run_session(*streams)
        return

    # Anything else is treated as a command to spawn and talk to over stdio.
    env_dict = dict(env)
    server_parameters = StdioServerParameters(command=command_or_url, args=args, env=env_dict)
    async with stdio_client(server_parameters) as streams:
        await run_session(*streams)
|
||||
|
||||
|
||||
def cli():
    """Command-line entry point: parse arguments and run the async client."""
    parser = argparse.ArgumentParser()
    parser.add_argument("command_or_url", help="Command or URL to connect to")
    parser.add_argument("args", nargs="*", help="Additional arguments")
    env_options = dict(
        nargs=2,
        action="append",
        metavar=("KEY", "VALUE"),
        help="Environment variables to set. Can be used multiple times.",
        default=[],
    )
    parser.add_argument("-e", "--env", **env_options)

    parsed = parser.parse_args()
    anyio.run(partial(main, parsed.command_or_url, parsed.args, parsed.env), backend="trio")
|
||||
|
||||
|
||||
# Script entry point (also used when executed via ``python -m``).
if __name__ == "__main__":
    cli()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
OAuth2 Authentication implementation for HTTPX.
|
||||
|
||||
Implements authorization code flow with PKCE and automatic token refresh.
|
||||
"""
|
||||
|
||||
from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
|
||||
from mcp.client.auth.oauth2 import (
|
||||
OAuthClientProvider,
|
||||
PKCEParameters,
|
||||
TokenStorage,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OAuthClientProvider",
|
||||
"OAuthFlowError",
|
||||
"OAuthRegistrationError",
|
||||
"OAuthTokenError",
|
||||
"PKCEParameters",
|
||||
"TokenStorage",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,10 @@
|
||||
class OAuthFlowError(Exception):
    """Base exception for OAuth flow errors.

    Root of the OAuth client exception hierarchy; catch this to handle any
    failure raised during the OAuth flow.
    """
|
||||
|
||||
|
||||
# Token-level failures (e.g. token exchange or refresh).
class OAuthTokenError(OAuthFlowError):
    """Raised when token operations fail."""
|
||||
|
||||
|
||||
# Dynamic client registration failures.
class OAuthRegistrationError(OAuthFlowError):
    """Raised when client registration fails."""
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
OAuth client credential extensions for MCP.
|
||||
|
||||
Provides OAuth providers for machine-to-machine authentication flows:
|
||||
- ClientCredentialsOAuthProvider: For client_credentials with client_id + client_secret
|
||||
- PrivateKeyJWTOAuthProvider: For client_credentials with private_key_jwt authentication
|
||||
(typically using a pre-built JWT from workload identity federation)
|
||||
- RFC7523OAuthClientProvider: For jwt-bearer grant (RFC 7523 Section 2.1)
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
|
||||
|
||||
|
||||
class ClientCredentialsOAuthProvider(OAuthClientProvider):
    """OAuth provider for client_credentials grant with client_id + client_secret.

    This provider sets client_info directly, bypassing dynamic client registration.
    Use this when you already have client credentials (client_id and client_secret).

    Example:
        ```python
        provider = ClientCredentialsOAuthProvider(
            server_url="https://api.example.com",
            storage=my_token_storage,
            client_id="my-client-id",
            client_secret="my-client-secret",
        )
        ```
    """

    def __init__(
        self,
        server_url: str,
        storage: TokenStorage,
        client_id: str,
        client_secret: str,
        token_endpoint_auth_method: Literal["client_secret_basic", "client_secret_post"] = "client_secret_basic",
        scopes: str | None = None,
    ) -> None:
        """Initialize client_credentials OAuth provider.

        Args:
            server_url: The MCP server URL.
            storage: Token storage implementation.
            client_id: The OAuth client ID.
            client_secret: The OAuth client secret.
            token_endpoint_auth_method: Authentication method for token endpoint.
                Either "client_secret_basic" (default) or "client_secret_post".
            scopes: Optional space-separated list of scopes to request.
        """
        # Build minimal client_metadata for the base class
        client_metadata = OAuthClientMetadata(
            redirect_uris=None,
            grant_types=["client_credentials"],
            token_endpoint_auth_method=token_endpoint_auth_method,
            scope=scopes,
        )
        # No redirect or callback handlers: this is a non-interactive
        # machine-to-machine flow. 300.0 is the timeout passed to the base class.
        super().__init__(server_url, client_metadata, storage, None, None, 300.0)
        # Store client_info to be set during _initialize - no dynamic registration needed
        self._fixed_client_info = OAuthClientInformationFull(
            redirect_uris=None,
            client_id=client_id,
            client_secret=client_secret,
            grant_types=["client_credentials"],
            token_endpoint_auth_method=token_endpoint_auth_method,
            scope=scopes,
        )

    async def _initialize(self) -> None:
        """Load stored tokens and set pre-configured client_info."""
        self.context.current_tokens = await self.context.storage.get_tokens()
        self.context.client_info = self._fixed_client_info
        self._initialized = True

    async def _perform_authorization(self) -> httpx.Request:
        """Perform client_credentials authorization.

        Returns the token-endpoint request directly: no browser interaction,
        no authorization code, no PKCE.
        """
        return await self._exchange_token_client_credentials()

    async def _exchange_token_client_credentials(self) -> httpx.Request:
        """Build token exchange request for client_credentials grant."""
        token_data: dict[str, Any] = {
            "grant_type": "client_credentials",
        }

        headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}

        # Use standard auth methods (client_secret_basic, client_secret_post, none)
        token_data, headers = self.context.prepare_token_auth(token_data, headers)

        # RFC 8707 resource indicator, when the server/protocol version supports it.
        if self.context.should_include_resource_param(self.context.protocol_version):
            token_data["resource"] = self.context.get_resource_url()

        if self.context.client_metadata.scope:
            token_data["scope"] = self.context.client_metadata.scope

        token_url = self._get_token_endpoint()
        return httpx.Request("POST", token_url, data=token_data, headers=headers)
|
||||
|
||||
|
||||
def static_assertion_provider(token: str) -> Callable[[str], Awaitable[str]]:
    """Create an assertion provider that always returns the same pre-built JWT.

    Use this when you have a JWT acquired out-of-band (e.g. from workload
    identity federation) that does not depend on the audience parameter.

    Example:
        ```python
        provider = PrivateKeyJWTOAuthProvider(
            server_url="https://api.example.com",
            storage=my_token_storage,
            client_id="my-client-id",
            assertion_provider=static_assertion_provider(my_prebuilt_jwt),
        )
        ```

    Args:
        token: The pre-built JWT assertion string.

    Returns:
        An async callback suitable for use as an assertion_provider.
    """

    async def _static(audience: str) -> str:
        # The audience is deliberately ignored: the JWT was built out-of-band.
        return token

    return _static
|
||||
|
||||
|
||||
class SignedJWTParameters(BaseModel):
    """Parameters for creating SDK-signed JWT assertions.

    Use `create_assertion_provider()` to create an assertion provider callback
    for use with `PrivateKeyJWTOAuthProvider`.

    Example:
        ```python
        jwt_params = SignedJWTParameters(
            issuer="my-client-id",
            subject="my-client-id",
            signing_key=private_key_pem,
        )
        provider = PrivateKeyJWTOAuthProvider(
            server_url="https://api.example.com",
            storage=my_token_storage,
            client_id="my-client-id",
            assertion_provider=jwt_params.create_assertion_provider(),
        )
        ```
    """

    issuer: str = Field(description="Issuer for JWT assertions (typically client_id).")
    subject: str = Field(description="Subject identifier for JWT assertions (typically client_id).")
    signing_key: str = Field(description="Private key for JWT signing (PEM format).")
    signing_algorithm: str = Field(default="RS256", description="Algorithm for signing JWT assertions.")
    lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
    additional_claims: dict[str, Any] | None = Field(default=None, description="Additional claims.")

    def create_assertion_provider(self) -> Callable[[str], Awaitable[str]]:
        """Create an assertion provider callback for use with PrivateKeyJWTOAuthProvider.

        Returns:
            An async callback that takes the audience (authorization server issuer URL)
            and returns a signed JWT assertion.
        """

        async def provider(audience: str) -> str:
            # A fresh JWT is signed on every call, so each assertion gets its own
            # iat/exp window and a unique "jti".
            now = int(time.time())
            claims: dict[str, Any] = {
                "iss": self.issuer,
                "sub": self.subject,
                "aud": audience,
                "exp": now + self.lifetime_seconds,
                "iat": now,
                "jti": str(uuid4()),
            }
            # Caller-supplied claims are merged last and may override the
            # standard ones above.
            if self.additional_claims:
                claims.update(self.additional_claims)

            return jwt.encode(claims, self.signing_key, algorithm=self.signing_algorithm)

        return provider
|
||||
|
||||
|
||||
class PrivateKeyJWTOAuthProvider(OAuthClientProvider):
    """OAuth provider for client_credentials grant with private_key_jwt authentication.

    Uses RFC 7523 Section 2.2 for client authentication via JWT assertion.

    The JWT assertion's audience MUST be the authorization server's issuer identifier
    (per RFC 7523bis security updates). The `assertion_provider` callback receives
    this audience value and must return a JWT with that audience.

    **Option 1: Pre-built JWT via Workload Identity Federation**

    In production scenarios, the JWT assertion is typically obtained from a workload
    identity provider (e.g., GCP, AWS IAM, Azure AD):

    ```python
    async def get_workload_identity_token(audience: str) -> str:
        # Fetch JWT from your identity provider
        # The JWT's audience must match the provided audience parameter
        return await fetch_token_from_identity_provider(audience=audience)

    provider = PrivateKeyJWTOAuthProvider(
        server_url="https://api.example.com",
        storage=my_token_storage,
        client_id="my-client-id",
        assertion_provider=get_workload_identity_token,
    )
    ```

    **Option 2: Static pre-built JWT**

    If you have a static JWT that doesn't need the audience parameter:

    ```python
    provider = PrivateKeyJWTOAuthProvider(
        server_url="https://api.example.com",
        storage=my_token_storage,
        client_id="my-client-id",
        assertion_provider=static_assertion_provider(my_prebuilt_jwt),
    )
    ```

    **Option 3: SDK-signed JWT (for testing/simple setups)**

    For testing or simple deployments, use `SignedJWTParameters.create_assertion_provider()`:

    ```python
    jwt_params = SignedJWTParameters(
        issuer="my-client-id",
        subject="my-client-id",
        signing_key=private_key_pem,
    )
    provider = PrivateKeyJWTOAuthProvider(
        server_url="https://api.example.com",
        storage=my_token_storage,
        client_id="my-client-id",
        assertion_provider=jwt_params.create_assertion_provider(),
    )
    ```
    """

    def __init__(
        self,
        server_url: str,
        storage: TokenStorage,
        client_id: str,
        assertion_provider: Callable[[str], Awaitable[str]],
        scopes: str | None = None,
    ) -> None:
        """Initialize private_key_jwt OAuth provider.

        Args:
            server_url: The MCP server URL.
            storage: Token storage implementation.
            client_id: The OAuth client ID.
            assertion_provider: Async callback that takes the audience (authorization
                server's issuer identifier) and returns a JWT assertion. Use
                `SignedJWTParameters.create_assertion_provider()` for SDK-signed JWTs,
                `static_assertion_provider()` for pre-built JWTs, or provide your own
                callback for workload identity federation.
            scopes: Optional space-separated list of scopes to request.
        """
        # Build minimal client_metadata for the base class
        client_metadata = OAuthClientMetadata(
            redirect_uris=None,
            grant_types=["client_credentials"],
            token_endpoint_auth_method="private_key_jwt",
            scope=scopes,
        )
        # No redirect or callback handlers: machine-to-machine flow, 300.0s timeout.
        super().__init__(server_url, client_metadata, storage, None, None, 300.0)
        self._assertion_provider = assertion_provider
        # Store client_info to be set during _initialize - no dynamic registration needed.
        # Note: no client_secret here — authentication happens via the JWT assertion.
        self._fixed_client_info = OAuthClientInformationFull(
            redirect_uris=None,
            client_id=client_id,
            grant_types=["client_credentials"],
            token_endpoint_auth_method="private_key_jwt",
            scope=scopes,
        )

    async def _initialize(self) -> None:
        """Load stored tokens and set pre-configured client_info."""
        self.context.current_tokens = await self.context.storage.get_tokens()
        self.context.client_info = self._fixed_client_info
        self._initialized = True

    async def _perform_authorization(self) -> httpx.Request:
        """Perform client_credentials authorization with private_key_jwt."""
        return await self._exchange_token_client_credentials()

    async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> None:
        """Add JWT assertion for client authentication to token endpoint parameters.

        Mutates `token_data` in place; raises OAuthFlowError if OAuth metadata
        has not been discovered yet (the issuer is needed as the JWT audience).
        """
        if not self.context.oauth_metadata:
            raise OAuthFlowError("Missing OAuth metadata for private_key_jwt flow")  # pragma: no cover

        # Audience MUST be the issuer identifier of the authorization server
        # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01
        audience = str(self.context.oauth_metadata.issuer)
        assertion = await self._assertion_provider(audience)

        # RFC 7523 Section 2.2: client authentication via JWT
        token_data["client_assertion"] = assertion
        token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"

    async def _exchange_token_client_credentials(self) -> httpx.Request:
        """Build token exchange request for client_credentials grant with private_key_jwt."""
        token_data: dict[str, Any] = {
            "grant_type": "client_credentials",
        }

        headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}

        # Add JWT client authentication (RFC 7523 Section 2.2)
        await self._add_client_authentication_jwt(token_data=token_data)

        # RFC 8707 resource indicator, when the server/protocol version supports it.
        if self.context.should_include_resource_param(self.context.protocol_version):
            token_data["resource"] = self.context.get_resource_url()

        if self.context.client_metadata.scope:
            token_data["scope"] = self.context.client_metadata.scope

        token_url = self._get_token_endpoint()
        return httpx.Request("POST", token_url, data=token_data, headers=headers)
|
||||
|
||||
|
||||
class JWTParameters(BaseModel):
    """JWT parameters for building (or passing through) a JWT assertion."""

    assertion: str | None = Field(
        default=None,
        description="JWT assertion for JWT authentication. "
        "Will be used instead of generating a new assertion if provided.",
    )

    issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
    subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
    audience: str | None = Field(default=None, description="Audience for JWT assertions.")
    claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
    jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
    jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
    jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")

    def to_assertion(self, with_audience_fallback: str | None = None) -> str:
        """Return the JWT assertion: the pre-built one if set, otherwise a freshly signed JWT.

        Args:
            with_audience_fallback: Audience to use when `self.audience` is unset
                (typically the authorization server's issuer identifier).

        Raises:
            OAuthFlowError: If a JWT must be generated but the signing key,
                issuer, subject, or audience is missing.
        """
        if self.assertion is not None:
            # Prebuilt JWT (e.g. acquired out-of-band)
            assertion = self.assertion
        else:
            if not self.jwt_signing_key:
                raise OAuthFlowError("Missing signing key for JWT bearer grant")  # pragma: no cover
            if not self.issuer:
                raise OAuthFlowError("Missing issuer for JWT bearer grant")  # pragma: no cover
            if not self.subject:
                raise OAuthFlowError("Missing subject for JWT bearer grant")  # pragma: no cover

            audience = self.audience if self.audience else with_audience_fallback
            if not audience:
                raise OAuthFlowError("Missing audience for JWT bearer grant")  # pragma: no cover

            now = int(time.time())
            claims: dict[str, Any] = {
                "iss": self.issuer,
                "sub": self.subject,
                "aud": audience,
                "exp": now + self.jwt_lifetime_seconds,
                "iat": now,
                "jti": str(uuid4()),
            }
            # Caller-supplied claims are merged last and may override the standard ones.
            claims.update(self.claims or {})

            assertion = jwt.encode(
                claims,
                self.jwt_signing_key,
                algorithm=self.jwt_signing_algorithm or "RS256",
            )
        return assertion
|
||||
|
||||
|
||||
class RFC7523OAuthClientProvider(OAuthClientProvider):
    """OAuth client provider for RFC 7523 jwt-bearer grant.

    .. deprecated::
        Use :class:`ClientCredentialsOAuthProvider` for client_credentials with
        client_id + client_secret, or :class:`PrivateKeyJWTOAuthProvider` for
        client_credentials with private_key_jwt authentication instead.

    This provider supports the jwt-bearer authorization grant (RFC 7523 Section 2.1)
    where the JWT itself is the authorization grant.
    """

    def __init__(
        self,
        server_url: str,
        client_metadata: OAuthClientMetadata,
        storage: TokenStorage,
        redirect_handler: Callable[[str], Awaitable[None]] | None = None,
        callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
        timeout: float = 300.0,
        jwt_parameters: JWTParameters | None = None,
    ) -> None:
        """Initialize the deprecated RFC 7523 provider; emits a DeprecationWarning."""
        # Local import keeps the warning machinery out of module import time.
        import warnings

        warnings.warn(
            "RFC7523OAuthClientProvider is deprecated. Use ClientCredentialsOAuthProvider "
            "or PrivateKeyJWTOAuthProvider instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
        self.jwt_parameters = jwt_parameters

    async def _exchange_token_authorization_code(
        self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
    ) -> httpx.Request:  # pragma: no cover
        """Build token exchange request for authorization_code flow."""
        token_data = token_data or {}
        # For private_key_jwt clients, attach the JWT client authentication
        # before delegating to the base implementation.
        if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
            self._add_client_authentication_jwt(token_data=token_data)
        return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)

    async def _perform_authorization(self) -> httpx.Request:  # pragma: no cover
        """Perform the authorization flow.

        jwt-bearer clients skip the interactive flow entirely; everyone else
        falls through to the base class behavior.
        """
        if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
            token_request = await self._exchange_token_jwt_bearer()
            return token_request
        else:
            return await super()._perform_authorization()

    def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]):  # pragma: no cover
        """Add JWT assertion for client authentication to token endpoint parameters."""
        if not self.jwt_parameters:
            raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
        if not self.context.oauth_metadata:
            raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow")

        # We need to set the audience to the issuer identifier of the authorization server
        # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
        issuer = str(self.context.oauth_metadata.issuer)
        assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)

        # When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
        token_data["client_assertion"] = assertion
        token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
        # The "audience" form field names the resource server that will validate the
        # issued token; it is different from the "aud" claim inside the assertion above.
        token_data["audience"] = self.context.get_resource_url()

    async def _exchange_token_jwt_bearer(self) -> httpx.Request:
        """Build token exchange request for JWT bearer grant."""
        if not self.context.client_info:
            raise OAuthFlowError("Missing client info")  # pragma: no cover
        if not self.jwt_parameters:
            raise OAuthFlowError("Missing JWT parameters")  # pragma: no cover
        if not self.context.oauth_metadata:
            raise OAuthTokenError("Missing OAuth metadata")  # pragma: no cover

        # We need to set the audience to the issuer identifier of the authorization server
        # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
        issuer = str(self.context.oauth_metadata.issuer)
        assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)

        # RFC 7523 Section 2.1: the JWT itself is the authorization grant.
        token_data = {
            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
            "assertion": assertion,
        }

        if self.context.should_include_resource_param(self.context.protocol_version):  # pragma: no branch
            token_data["resource"] = self.context.get_resource_url()

        if self.context.client_metadata.scope:  # pragma: no branch
            token_data["scope"] = self.context.client_metadata.scope

        token_url = self._get_token_endpoint()
        return httpx.Request(
            "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
        )
|
||||
616
.venv/lib/python3.11/site-packages/mcp/client/auth/oauth2.py
Normal file
616
.venv/lib/python3.11/site-packages/mcp/client/auth/oauth2.py
Normal file
@@ -0,0 +1,616 @@
|
||||
"""
|
||||
OAuth2 Authentication implementation for HTTPX.
|
||||
|
||||
Implements authorization code flow with PKCE and automatic token refresh.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
from urllib.parse import quote, urlencode, urljoin, urlparse
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from mcp.client.auth.exceptions import OAuthFlowError, OAuthTokenError
|
||||
from mcp.client.auth.utils import (
|
||||
build_oauth_authorization_server_metadata_discovery_urls,
|
||||
build_protected_resource_metadata_discovery_urls,
|
||||
create_client_info_from_metadata_url,
|
||||
create_client_registration_request,
|
||||
create_oauth_metadata_request,
|
||||
extract_field_from_www_auth,
|
||||
extract_resource_metadata_from_www_auth,
|
||||
extract_scope_from_www_auth,
|
||||
get_client_metadata_scopes,
|
||||
handle_auth_metadata_response,
|
||||
handle_protected_resource_response,
|
||||
handle_registration_response,
|
||||
handle_token_response_scopes,
|
||||
is_valid_client_metadata_url,
|
||||
should_use_client_metadata_url,
|
||||
)
|
||||
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
|
||||
from mcp.shared.auth import (
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthToken,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
from mcp.shared.auth_utils import (
|
||||
calculate_token_expiry,
|
||||
check_resource_allowed,
|
||||
resource_url_from_server_url,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PKCEParameters(BaseModel):
    """PKCE (Proof Key for Code Exchange) parameters."""

    code_verifier: str = Field(..., min_length=43, max_length=128)
    code_challenge: str = Field(..., min_length=43, max_length=128)

    @classmethod
    def generate(cls) -> "PKCEParameters":
        """Generate a fresh verifier/challenge pair (S256 method)."""
        # RFC 7636 unreserved characters: ALPHA / DIGIT / "-" / "." / "_" / "~"
        alphabet = string.ascii_letters + string.digits + "-._~"
        verifier = "".join(secrets.choice(alphabet) for _ in range(128))
        # Challenge = BASE64URL(SHA256(verifier)) without "=" padding.
        challenge = base64.urlsafe_b64encode(hashlib.sha256(verifier.encode()).digest()).decode().rstrip("=")
        return cls(code_verifier=verifier, code_challenge=challenge)
|
||||
|
||||
|
||||
class TokenStorage(Protocol):
    """Protocol for token storage implementations.

    Structural interface: any object providing these four async methods
    satisfies it, with no subclassing required.
    """

    async def get_tokens(self) -> OAuthToken | None:
        """Get stored tokens."""
        ...

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Store tokens."""
        ...

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Get stored client information."""
        ...

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Store client information."""
        ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthContext:
|
||||
"""OAuth flow context."""
|
||||
|
||||
server_url: str
|
||||
client_metadata: OAuthClientMetadata
|
||||
storage: TokenStorage
|
||||
redirect_handler: Callable[[str], Awaitable[None]] | None
|
||||
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None
|
||||
timeout: float = 300.0
|
||||
client_metadata_url: str | None = None
|
||||
|
||||
# Discovered metadata
|
||||
protected_resource_metadata: ProtectedResourceMetadata | None = None
|
||||
oauth_metadata: OAuthMetadata | None = None
|
||||
auth_server_url: str | None = None
|
||||
protocol_version: str | None = None
|
||||
|
||||
# Client registration
|
||||
client_info: OAuthClientInformationFull | None = None
|
||||
|
||||
# Token management
|
||||
current_tokens: OAuthToken | None = None
|
||||
token_expiry_time: float | None = None
|
||||
|
||||
# State
|
||||
lock: anyio.Lock = field(default_factory=anyio.Lock)
|
||||
|
||||
def get_authorization_base_url(self, server_url: str) -> str:
    """Return only the scheme and host[:port] of *server_url*, dropping the path."""
    parts = urlparse(server_url)
    return f"{parts.scheme}://{parts.netloc}"
|
||||
|
||||
def update_token_expiry(self, token: OAuthToken) -> None:
    """Update token expiry time using shared util function."""
    # expires_in is a relative lifetime in seconds; calculate_token_expiry
    # converts it to the absolute timestamp that is_token_valid() compares against.
    self.token_expiry_time = calculate_token_expiry(token.expires_in)
|
||||
|
||||
def is_token_valid(self) -> bool:
    """Check if current token is valid."""
    tokens = self.current_tokens
    if not tokens or not tokens.access_token:
        return False
    # A missing/zero expiry time means we never recorded one: treat as valid.
    if not self.token_expiry_time:
        return True
    return time.time() <= self.token_expiry_time
|
||||
|
||||
def can_refresh_token(self) -> bool:
    """Check if token can be refreshed."""
    tokens = self.current_tokens
    # Need a refresh token AND registered client credentials to attempt refresh.
    return bool(tokens and tokens.refresh_token and self.client_info)
|
||||
|
||||
def clear_tokens(self) -> None:
    """Clear current tokens."""
    # Reset both the tokens and their expiry so is_token_valid() returns False.
    self.current_tokens = None
    self.token_expiry_time = None
|
||||
|
||||
def get_resource_url(self) -> str:
    """Get resource URL for RFC 8707.

    Uses PRM resource if it's a valid parent, otherwise uses canonical server URL.
    """
    canonical = resource_url_from_server_url(self.server_url)

    prm = self.protected_resource_metadata
    if not prm or not prm.resource:
        return canonical

    # Adopt the PRM-advertised resource only when it is a valid parent of
    # the server's canonical URL.
    prm_resource = str(prm.resource)
    if check_resource_allowed(requested_resource=canonical, configured_resource=prm_resource):
        return prm_resource
    return canonical
|
||||
|
||||
def should_include_resource_param(self, protocol_version: str | None = None) -> bool:
|
||||
"""Determine if the resource parameter should be included in OAuth requests.
|
||||
|
||||
Returns True if:
|
||||
- Protected resource metadata is available, OR
|
||||
- MCP-Protocol-Version header is 2025-06-18 or later
|
||||
"""
|
||||
# If we have protected resource metadata, include the resource param
|
||||
if self.protected_resource_metadata is not None:
|
||||
return True
|
||||
|
||||
# If no protocol version provided, don't include resource param
|
||||
if not protocol_version:
|
||||
return False
|
||||
|
||||
# Check if protocol version is 2025-06-18 or later
|
||||
# Version format is YYYY-MM-DD, so string comparison works
|
||||
return protocol_version >= "2025-06-18"
|
||||
|
||||
def prepare_token_auth(
|
||||
self, data: dict[str, str], headers: dict[str, str] | None = None
|
||||
) -> tuple[dict[str, str], dict[str, str]]:
|
||||
"""Prepare authentication for token requests.
|
||||
|
||||
Args:
|
||||
data: The form data to send
|
||||
headers: Optional headers dict to update
|
||||
|
||||
Returns:
|
||||
Tuple of (updated_data, updated_headers)
|
||||
"""
|
||||
if headers is None:
|
||||
headers = {} # pragma: no cover
|
||||
|
||||
if not self.client_info:
|
||||
return data, headers # pragma: no cover
|
||||
|
||||
auth_method = self.client_info.token_endpoint_auth_method
|
||||
|
||||
if auth_method == "client_secret_basic" and self.client_info.client_id and self.client_info.client_secret:
|
||||
# URL-encode client ID and secret per RFC 6749 Section 2.3.1
|
||||
encoded_id = quote(self.client_info.client_id, safe="")
|
||||
encoded_secret = quote(self.client_info.client_secret, safe="")
|
||||
credentials = f"{encoded_id}:{encoded_secret}"
|
||||
encoded_credentials = base64.b64encode(credentials.encode()).decode()
|
||||
headers["Authorization"] = f"Basic {encoded_credentials}"
|
||||
# Don't include client_secret in body for basic auth
|
||||
data = {k: v for k, v in data.items() if k != "client_secret"}
|
||||
elif auth_method == "client_secret_post" and self.client_info.client_secret:
|
||||
# Include client_secret in request body
|
||||
data["client_secret"] = self.client_info.client_secret
|
||||
# For auth_method == "none", don't add any client_secret
|
||||
|
||||
return data, headers
|
||||
|
||||
|
||||
class OAuthClientProvider(httpx.Auth):
    """
    OAuth2 authentication for httpx.

    Handles OAuth flow with automatic client registration and token storage.
    """

    # Ask httpx to make response bodies available to the auth flow so it can
    # read metadata documents and error payloads from intermediate responses.
    requires_response_body = True
|
||||
|
||||
    def __init__(
        self,
        server_url: str,
        client_metadata: OAuthClientMetadata,
        storage: TokenStorage,
        redirect_handler: Callable[[str], Awaitable[None]] | None = None,
        callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
        timeout: float = 300.0,
        client_metadata_url: str | None = None,
    ):
        """Initialize OAuth2 authentication.

        Args:
            server_url: The MCP server URL.
            client_metadata: OAuth client metadata for registration.
            storage: Token storage implementation.
            redirect_handler: Handler for authorization redirects.
            callback_handler: Handler for authorization callbacks.
            timeout: Timeout for the OAuth flow.
            client_metadata_url: URL-based client ID. When provided and the server
                advertises client_id_metadata_document_supported=true, this URL will be
                used as the client_id instead of performing dynamic client registration.
                Must be a valid HTTPS URL with a non-root pathname.

        Raises:
            ValueError: If client_metadata_url is provided but not a valid HTTPS URL
                with a non-root pathname.
        """
        # Validate client_metadata_url if provided
        if client_metadata_url is not None and not is_valid_client_metadata_url(client_metadata_url):
            raise ValueError(
                f"client_metadata_url must be a valid HTTPS URL with a non-root pathname, got: {client_metadata_url}"
            )

        self.context = OAuthContext(
            server_url=server_url,
            client_metadata=client_metadata,
            storage=storage,
            redirect_handler=redirect_handler,
            callback_handler=callback_handler,
            timeout=timeout,
            client_metadata_url=client_metadata_url,
        )
        # Stored tokens/client info are loaded lazily on the first auth flow
        # (see _initialize) because storage access is async.
        self._initialized = False
|
||||
|
||||
    async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
        """
        Handle protected resource metadata discovery response.

        Per SEP-985, supports fallback when discovery fails at one URL.

        Returns:
            True if metadata was successfully discovered, False if we should try next URL

        Raises:
            OAuthFlowError: On any status other than 200 or 404.
        """
        if response.status_code == 200:
            try:
                content = await response.aread()
                metadata = ProtectedResourceMetadata.model_validate_json(content)
                self.context.protected_resource_metadata = metadata
                # Only the first advertised authorization server is used.
                if metadata.authorization_servers:  # pragma: no branch
                    self.context.auth_server_url = str(metadata.authorization_servers[0])
                return True

            except ValidationError:  # pragma: no cover
                # Invalid metadata - try next URL
                logger.warning(f"Invalid protected resource metadata at {response.request.url}")
                return False
        elif response.status_code == 404:  # pragma: no cover
            # Not found - try next URL in fallback chain
            logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
            return False
        else:
            # Other error - fail immediately
            raise OAuthFlowError(
                f"Protected Resource Metadata request failed: {response.status_code}"
            )  # pragma: no cover
|
||||
|
||||
async def _perform_authorization(self) -> httpx.Request:
|
||||
"""Perform the authorization flow."""
|
||||
auth_code, code_verifier = await self._perform_authorization_code_grant()
|
||||
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
|
||||
return token_request
|
||||
|
||||
    async def _perform_authorization_code_grant(self) -> tuple[str, str]:
        """Perform the authorization redirect and get auth code.

        Builds the authorization URL (with PKCE challenge and CSRF state),
        hands it to the redirect handler, then waits on the callback handler
        for the (code, state) pair.

        Returns:
            (auth_code, code_verifier) for the subsequent token exchange.

        Raises:
            OAuthFlowError: On missing configuration, state mismatch, or a
                missing authorization code.
        """
        if self.context.client_metadata.redirect_uris is None:
            raise OAuthFlowError("No redirect URIs provided for authorization code grant")  # pragma: no cover
        if not self.context.redirect_handler:
            raise OAuthFlowError("No redirect handler provided for authorization code grant")  # pragma: no cover
        if not self.context.callback_handler:
            raise OAuthFlowError("No callback handler provided for authorization code grant")  # pragma: no cover

        if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint:
            auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint)  # pragma: no cover
        else:
            # Legacy fallback: conventional /authorize on the server's origin.
            auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
            auth_endpoint = urljoin(auth_base_url, "/authorize")

        if not self.context.client_info:
            raise OAuthFlowError("No client info available for authorization")  # pragma: no cover

        # Generate PKCE parameters
        pkce_params = PKCEParameters.generate()
        state = secrets.token_urlsafe(32)

        auth_params = {
            "response_type": "code",
            "client_id": self.context.client_info.client_id,
            "redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
            "state": state,
            "code_challenge": pkce_params.code_challenge,
            "code_challenge_method": "S256",
        }

        # Only include resource param if conditions are met
        if self.context.should_include_resource_param(self.context.protocol_version):
            auth_params["resource"] = self.context.get_resource_url()  # RFC 8707 # pragma: no cover

        if self.context.client_metadata.scope:  # pragma: no branch
            auth_params["scope"] = self.context.client_metadata.scope

        authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
        await self.context.redirect_handler(authorization_url)

        # Wait for callback
        auth_code, returned_state = await self.context.callback_handler()

        # Constant-time comparison guards the CSRF state check.
        if returned_state is None or not secrets.compare_digest(returned_state, state):
            raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}")  # pragma: no cover

        if not auth_code:
            raise OAuthFlowError("No authorization code received")  # pragma: no cover

        # Return auth code and code verifier for token exchange
        return auth_code, pkce_params.code_verifier
|
||||
|
||||
def _get_token_endpoint(self) -> str:
|
||||
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
|
||||
token_url = str(self.context.oauth_metadata.token_endpoint)
|
||||
else:
|
||||
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
|
||||
token_url = urljoin(auth_base_url, "/token")
|
||||
return token_url
|
||||
|
||||
async def _exchange_token_authorization_code(
|
||||
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {}
|
||||
) -> httpx.Request:
|
||||
"""Build token exchange request for authorization_code flow."""
|
||||
if self.context.client_metadata.redirect_uris is None:
|
||||
raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover
|
||||
if not self.context.client_info:
|
||||
raise OAuthFlowError("Missing client info") # pragma: no cover
|
||||
|
||||
token_url = self._get_token_endpoint()
|
||||
token_data = token_data or {}
|
||||
token_data.update(
|
||||
{
|
||||
"grant_type": "authorization_code",
|
||||
"code": auth_code,
|
||||
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
|
||||
"client_id": self.context.client_info.client_id,
|
||||
"code_verifier": code_verifier,
|
||||
}
|
||||
)
|
||||
|
||||
# Only include resource param if conditions are met
|
||||
if self.context.should_include_resource_param(self.context.protocol_version):
|
||||
token_data["resource"] = self.context.get_resource_url() # RFC 8707
|
||||
|
||||
# Prepare authentication based on preferred method
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
token_data, headers = self.context.prepare_token_auth(token_data, headers)
|
||||
|
||||
return httpx.Request("POST", token_url, data=token_data, headers=headers)
|
||||
|
||||
    async def _handle_token_response(self, response: httpx.Response) -> None:
        """Handle token exchange response.

        On success, caches the tokens on the context, records their expiry, and
        persists them to storage.

        Raises:
            OAuthTokenError: If the token endpoint returned a non-200 status.
        """
        if response.status_code != 200:
            body = await response.aread()  # pragma: no cover
            body_text = body.decode("utf-8")  # pragma: no cover
            raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}")  # pragma: no cover

        # Parse and validate response with scope validation
        token_response = await handle_token_response_scopes(response)

        # Store tokens in context
        self.context.current_tokens = token_response
        self.context.update_token_expiry(token_response)
        await self.context.storage.set_tokens(token_response)
|
||||
|
||||
async def _refresh_token(self) -> httpx.Request:
|
||||
"""Build token refresh request."""
|
||||
if not self.context.current_tokens or not self.context.current_tokens.refresh_token:
|
||||
raise OAuthTokenError("No refresh token available") # pragma: no cover
|
||||
|
||||
if not self.context.client_info or not self.context.client_info.client_id:
|
||||
raise OAuthTokenError("No client info available") # pragma: no cover
|
||||
|
||||
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
|
||||
token_url = str(self.context.oauth_metadata.token_endpoint) # pragma: no cover
|
||||
else:
|
||||
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
|
||||
token_url = urljoin(auth_base_url, "/token")
|
||||
|
||||
refresh_data: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": self.context.current_tokens.refresh_token,
|
||||
"client_id": self.context.client_info.client_id,
|
||||
}
|
||||
|
||||
# Only include resource param if conditions are met
|
||||
if self.context.should_include_resource_param(self.context.protocol_version):
|
||||
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
|
||||
|
||||
# Prepare authentication based on preferred method
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)
|
||||
|
||||
return httpx.Request("POST", token_url, data=refresh_data, headers=headers)
|
||||
|
||||
    async def _handle_refresh_response(self, response: httpx.Response) -> bool:  # pragma: no cover
        """Handle token refresh response. Returns True if successful.

        On any failure (non-200 status or invalid body) the cached tokens are
        cleared so the caller falls back to full re-authentication.
        """
        if response.status_code != 200:
            logger.warning(f"Token refresh failed: {response.status_code}")
            self.context.clear_tokens()
            return False

        try:
            content = await response.aread()
            token_response = OAuthToken.model_validate_json(content)

            # Cache, record expiry, and persist the new tokens.
            self.context.current_tokens = token_response
            self.context.update_token_expiry(token_response)
            await self.context.storage.set_tokens(token_response)

            return True
        except ValidationError:
            logger.exception("Invalid refresh response")
            self.context.clear_tokens()
            return False
|
||||
|
||||
    async def _initialize(self) -> None:  # pragma: no cover
        """Load stored tokens and client info.

        Runs lazily on the first auth flow because storage access is async and
        cannot happen inside __init__.
        """
        self.context.current_tokens = await self.context.storage.get_tokens()
        self.context.client_info = await self.context.storage.get_client_info()
        self._initialized = True
|
||||
|
||||
def _add_auth_header(self, request: httpx.Request) -> None:
|
||||
"""Add authorization header to request if we have valid tokens."""
|
||||
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
|
||||
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
|
||||
|
||||
    async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
        """Parse an authorization-server metadata body and cache it on the context.

        NOTE(review): assumes a successful response with a valid JSON body; a
        pydantic ValidationError propagates to the caller — confirm callers
        check status first.
        """
        content = await response.aread()
        metadata = OAuthMetadata.model_validate_json(content)
        self.context.oauth_metadata = metadata
|
||||
|
||||
    async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
        """HTTPX auth flow integration.

        Drives the entire OAuth lifecycle inline (the httpx auth API is a
        generator, so every intermediate request must be yielded from here):
        lazily load stored state, refresh the token when possible, attach the
        Bearer header, send the request, and on a 401 run metadata discovery
        (PRM then OASM), client registration (or CIMD), authorization, and
        token exchange before retrying. A 403 with ``insufficient_scope``
        triggers a re-authorization with updated scopes and a retry.
        """
        # The lock serializes concurrent requests so only one performs the
        # refresh/registration/authorization dance at a time.
        async with self.context.lock:
            if not self._initialized:
                await self._initialize()  # pragma: no cover

            # Capture protocol version from request headers
            self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)

            if not self.context.is_token_valid() and self.context.can_refresh_token():
                # Try to refresh token
                refresh_request = await self._refresh_token()  # pragma: no cover
                refresh_response = yield refresh_request  # pragma: no cover

                if not await self._handle_refresh_response(refresh_response):  # pragma: no cover
                    # Refresh failed, need full re-authentication
                    self._initialized = False

            if self.context.is_token_valid():
                self._add_auth_header(request)

            response = yield request

            if response.status_code == 401:
                # Perform full OAuth flow
                try:
                    # OAuth flow must be inline due to generator constraints
                    www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)

                    # Step 1: Discover protected resource metadata (SEP-985 with fallback support)
                    prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
                        www_auth_resource_metadata_url, self.context.server_url
                    )

                    for url in prm_discovery_urls:  # pragma: no branch
                        discovery_request = create_oauth_metadata_request(url)

                        discovery_response = yield discovery_request  # sending request

                        prm = await handle_protected_resource_response(discovery_response)
                        if prm:
                            self.context.protected_resource_metadata = prm

                            # todo: try all authorization_servers to find the OASM
                            assert (
                                len(prm.authorization_servers) > 0
                            )  # this is always true as authorization_servers has a min length of 1

                            self.context.auth_server_url = str(prm.authorization_servers[0])
                            break
                        else:
                            logger.debug(f"Protected resource metadata discovery failed: {url}")

                    asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
                        self.context.auth_server_url, self.context.server_url
                    )

                    # Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
                    for url in asm_discovery_urls:  # pragma: no cover
                        oauth_metadata_request = create_oauth_metadata_request(url)
                        oauth_metadata_response = yield oauth_metadata_request

                        # ok=False means a hard (non-4XX) failure: stop the chain.
                        ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
                        if not ok:
                            break
                        if ok and asm:
                            self.context.oauth_metadata = asm
                            break
                        else:
                            logger.debug(f"OAuth metadata discovery failed: {url}")

                    # Step 3: Apply scope selection strategy
                    self.context.client_metadata.scope = get_client_metadata_scopes(
                        extract_scope_from_www_auth(response),
                        self.context.protected_resource_metadata,
                        self.context.oauth_metadata,
                    )

                    # Step 4: Register client or use URL-based client ID (CIMD)
                    if not self.context.client_info:
                        if should_use_client_metadata_url(
                            self.context.oauth_metadata, self.context.client_metadata_url
                        ):
                            # Use URL-based client ID (CIMD)
                            logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}")
                            client_information = create_client_info_from_metadata_url(
                                self.context.client_metadata_url,  # type: ignore[arg-type]
                                redirect_uris=self.context.client_metadata.redirect_uris,
                            )
                            self.context.client_info = client_information
                            await self.context.storage.set_client_info(client_information)
                        else:
                            # Fallback to Dynamic Client Registration
                            registration_request = create_client_registration_request(
                                self.context.oauth_metadata,
                                self.context.client_metadata,
                                self.context.get_authorization_base_url(self.context.server_url),
                            )
                            registration_response = yield registration_request
                            client_information = await handle_registration_response(registration_response)
                            self.context.client_info = client_information
                            await self.context.storage.set_client_info(client_information)

                    # Step 5: Perform authorization and complete token exchange
                    token_response = yield await self._perform_authorization()
                    await self._handle_token_response(token_response)
                except Exception:  # pragma: no cover
                    logger.exception("OAuth flow error")
                    raise

                # Retry with new tokens
                self._add_auth_header(request)
                yield request
            elif response.status_code == 403:
                # Step 1: Extract error field from WWW-Authenticate header
                error = extract_field_from_www_auth(response, "error")

                # Step 2: Check if we need to step-up authorization
                if error == "insufficient_scope":  # pragma: no branch
                    try:
                        # Step 2a: Update the required scopes
                        self.context.client_metadata.scope = get_client_metadata_scopes(
                            extract_scope_from_www_auth(response), self.context.protected_resource_metadata
                        )

                        # Step 2b: Perform (re-)authorization and token exchange
                        token_response = yield await self._perform_authorization()
                        await self._handle_token_response(token_response)
                    except Exception:  # pragma: no cover
                        logger.exception("OAuth flow error")
                        raise

                    # Retry with new tokens
                    self._add_auth_header(request)
                    yield request
|
||||
336
.venv/lib/python3.11/site-packages/mcp/client/auth/utils.py
Normal file
336
.venv/lib/python3.11/site-packages/mcp/client/auth/utils.py
Normal file
@@ -0,0 +1,336 @@
|
||||
import logging
|
||||
import re
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from httpx import Request, Response
|
||||
from pydantic import AnyUrl, ValidationError
|
||||
|
||||
from mcp.client.auth import OAuthRegistrationError, OAuthTokenError
|
||||
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
|
||||
from mcp.shared.auth import (
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthToken,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
from mcp.types import LATEST_PROTOCOL_VERSION
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_field_from_www_auth(response: Response, field_name: str) -> str | None:
|
||||
"""
|
||||
Extract field from WWW-Authenticate header.
|
||||
|
||||
Returns:
|
||||
Field value if found in WWW-Authenticate header, None otherwise
|
||||
"""
|
||||
www_auth_header = response.headers.get("WWW-Authenticate")
|
||||
if not www_auth_header:
|
||||
return None
|
||||
|
||||
# Pattern matches: field_name="value" or field_name=value (unquoted)
|
||||
pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))'
|
||||
match = re.search(pattern, www_auth_header)
|
||||
|
||||
if match:
|
||||
# Return quoted value if present, otherwise unquoted value
|
||||
return match.group(1) or match.group(2)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_scope_from_www_auth(response: Response) -> str | None:
    """
    Extract scope parameter from WWW-Authenticate header as per RFC6750.

    Returns:
        Scope string if found in WWW-Authenticate header, None otherwise
    """
    # Thin convenience wrapper over the generic field extractor.
    return extract_field_from_www_auth(response, "scope")
|
||||
|
||||
|
||||
def extract_resource_metadata_from_www_auth(response: Response) -> str | None:
    """
    Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.

    Returns:
        Resource metadata URL if found in WWW-Authenticate header, None otherwise
    """
    # Only a 401 challenge can legitimately advertise resource_metadata.
    if response and response.status_code == 401:
        return extract_field_from_www_auth(response, "resource_metadata")
    return None  # pragma: no cover
|
||||
|
||||
|
||||
def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]:
|
||||
"""
|
||||
Build ordered list of URLs to try for protected resource metadata discovery.
|
||||
|
||||
Per SEP-985, the client MUST:
|
||||
1. Try resource_metadata from WWW-Authenticate header (if present)
|
||||
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
|
||||
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
|
||||
|
||||
Args:
|
||||
www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header
|
||||
server_url: server url
|
||||
|
||||
Returns:
|
||||
Ordered list of URLs to try for discovery
|
||||
"""
|
||||
urls: list[str] = []
|
||||
|
||||
# Priority 1: WWW-Authenticate header with resource_metadata parameter
|
||||
if www_auth_url:
|
||||
urls.append(www_auth_url)
|
||||
|
||||
# Priority 2-3: Well-known URIs (RFC 9728)
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
# Priority 2: Path-based well-known URI (if server has a path component)
|
||||
if parsed.path and parsed.path != "/":
|
||||
path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}")
|
||||
urls.append(path_based_url)
|
||||
|
||||
# Priority 3: Root-based well-known URI
|
||||
root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
|
||||
urls.append(root_based_url)
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def get_client_metadata_scopes(
|
||||
www_authenticate_scope: str | None,
|
||||
protected_resource_metadata: ProtectedResourceMetadata | None,
|
||||
authorization_server_metadata: OAuthMetadata | None = None,
|
||||
) -> str | None:
|
||||
"""Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec."""
|
||||
# Per MCP spec, scope selection priority order:
|
||||
# 1. Use scope from WWW-Authenticate header (if provided)
|
||||
# 2. Use all scopes from PRM scopes_supported (if available)
|
||||
# 3. Omit scope parameter if neither is available
|
||||
|
||||
if www_authenticate_scope is not None:
|
||||
# Priority 1: WWW-Authenticate header scope
|
||||
return www_authenticate_scope
|
||||
elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None:
|
||||
# Priority 2: PRM scopes_supported
|
||||
return " ".join(protected_resource_metadata.scopes_supported)
|
||||
elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None:
|
||||
return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover
|
||||
else:
|
||||
# Priority 3: Omit scope parameter
|
||||
return None
|
||||
|
||||
|
||||
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
|
||||
"""
|
||||
Generate ordered list of (url, type) tuples for discovery attempts.
|
||||
|
||||
Args:
|
||||
auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None
|
||||
server_url: URL for the MCP server, used as a fallback if auth_server_url is None
|
||||
"""
|
||||
|
||||
if not auth_server_url:
|
||||
# Legacy path using the 2025-03-26 spec:
|
||||
# link: https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization
|
||||
parsed = urlparse(server_url)
|
||||
return [f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server"]
|
||||
|
||||
urls: list[str] = []
|
||||
parsed = urlparse(auth_server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
# RFC 8414: Path-aware OAuth discovery
|
||||
if parsed.path and parsed.path != "/":
|
||||
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
|
||||
urls.append(urljoin(base_url, oauth_path))
|
||||
|
||||
# RFC 8414 section 5: Path-aware OIDC discovery
|
||||
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
|
||||
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
|
||||
urls.append(urljoin(base_url, oidc_path))
|
||||
|
||||
# https://openid.net/specs/openid-connect-discovery-1_0.html
|
||||
oidc_path = f"{parsed.path.rstrip('/')}/.well-known/openid-configuration"
|
||||
urls.append(urljoin(base_url, oidc_path))
|
||||
return urls
|
||||
|
||||
# OAuth root
|
||||
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
|
||||
|
||||
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
|
||||
# https://openid.net/specs/openid-connect-discovery-1_0.html
|
||||
urls.append(urljoin(base_url, "/.well-known/openid-configuration"))
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
async def handle_protected_resource_response(
    response: Response,
) -> ProtectedResourceMetadata | None:
    """
    Handle protected resource metadata discovery response.

    Per SEP-985, supports fallback when discovery fails at one URL.

    Returns:
        The parsed ProtectedResourceMetadata on success, or None when the
        caller should try the next URL in the fallback chain (non-200 status
        or an invalid metadata document).
    """
    if response.status_code != 200:
        # Not found / error - try next URL in fallback chain
        return None

    try:
        content = await response.aread()
        return ProtectedResourceMetadata.model_validate_json(content)
    except ValidationError:  # pragma: no cover
        # Invalid metadata - try next URL
        return None
|
||||
|
||||
|
||||
async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuthMetadata | None]:
    """Parse an authorization-server metadata discovery response.

    Returns:
        (keep_trying, metadata). metadata is set only for a valid 200 body.
        keep_trying is False for non-4XX failures (the caller should stop the
        fallback chain); 4XX responses yield (True, None) so the caller can
        try the next candidate URL.
    """
    if response.status_code == 200:
        try:
            content = await response.aread()
            asm = OAuthMetadata.model_validate_json(content)
            return True, asm
        except ValidationError:  # pragma: no cover
            # Malformed body on a 200: treat as a miss and keep trying.
            return True, None
    elif response.status_code < 400 or response.status_code >= 500:
        return False, None  # Non-4XX error, stop trying
    return True, None
|
||||
|
||||
|
||||
def create_oauth_metadata_request(url: str) -> Request:
    """Build a GET request for an OAuth metadata document, advertising the
    latest MCP protocol version in the request headers."""
    return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
|
||||
|
||||
|
||||
def create_client_registration_request(
    auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str
) -> Request:
    """Build registration request or skip if already registered."""

    # Prefer the endpoint advertised by the authorization server; otherwise
    # fall back to the conventional /register path.
    if auth_server_metadata and auth_server_metadata.registration_endpoint:
        endpoint = str(auth_server_metadata.registration_endpoint)
    else:
        endpoint = urljoin(auth_base_url, "/register")

    payload = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)

    return Request("POST", endpoint, json=payload, headers={"Content-Type": "application/json"})
|
||||
|
||||
|
||||
async def handle_registration_response(response: Response) -> OAuthClientInformationFull:
    """Parse a dynamic client registration response.

    Returns:
        The registered client information.

    Raises:
        OAuthRegistrationError: On a non-2xx status or an invalid response body.
    """
    if response.status_code not in (200, 201):
        # Read the body first so response.text is populated for the message.
        await response.aread()
        raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")

    try:
        content = await response.aread()
        return OAuthClientInformationFull.model_validate_json(content)
    except ValidationError as e:  # pragma: no cover
        # Chain the cause so pydantic's details stay visible in tracebacks.
        raise OAuthRegistrationError(f"Invalid registration response: {e}") from e
|
||||
|
||||
|
||||
def is_valid_client_metadata_url(url: str | None) -> bool:
|
||||
"""Validate that a URL is suitable for use as a client_id (CIMD).
|
||||
|
||||
The URL must be HTTPS with a non-root pathname.
|
||||
|
||||
Args:
|
||||
url: The URL to validate
|
||||
|
||||
Returns:
|
||||
True if the URL is a valid HTTPS URL with a non-root pathname
|
||||
"""
|
||||
if not url:
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return parsed.scheme == "https" and parsed.path not in ("", "/")
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def should_use_client_metadata_url(
|
||||
oauth_metadata: OAuthMetadata | None,
|
||||
client_metadata_url: str | None,
|
||||
) -> bool:
|
||||
"""Determine if URL-based client ID (CIMD) should be used instead of DCR.
|
||||
|
||||
URL-based client IDs should be used when:
|
||||
1. The server advertises client_id_metadata_document_supported=true
|
||||
2. The client has a valid client_metadata_url configured
|
||||
|
||||
Args:
|
||||
oauth_metadata: OAuth authorization server metadata
|
||||
client_metadata_url: URL-based client ID (already validated)
|
||||
|
||||
Returns:
|
||||
True if CIMD should be used, False if DCR should be used
|
||||
"""
|
||||
if not client_metadata_url:
|
||||
return False
|
||||
|
||||
if not oauth_metadata:
|
||||
return False
|
||||
|
||||
return oauth_metadata.client_id_metadata_document_supported is True
|
||||
|
||||
|
||||
def create_client_info_from_metadata_url(
    client_metadata_url: str, redirect_uris: list[AnyUrl] | None = None
) -> OAuthClientInformationFull:
    """Build client information for a URL-based client ID (CIMD).

    With CIMD the metadata URL itself becomes the client_id; no client
    secret exists, so token_endpoint_auth_method is fixed to "none".

    Args:
        client_metadata_url: The URL to use as the client_id
        redirect_uris: Redirect URIs carried over from the client metadata
            (OAuthClientInformationFull inherits them from OAuthClientMetadata)

    Returns:
        OAuthClientInformationFull with the URL as client_id
    """
    client_info = OAuthClientInformationFull(
        client_id=client_metadata_url,
        token_endpoint_auth_method="none",
        redirect_uris=redirect_uris,
    )
    return client_info
|
||||
|
||||
|
||||
async def handle_token_response_scopes(
    response: Response,
) -> OAuthToken:
    """Parse and validate token response with optional scope validation.

    Parses token response JSON. Callers should check response.status_code before calling.

    Args:
        response: HTTP response from token endpoint (status already checked by caller)

    Returns:
        Validated OAuthToken model

    Raises:
        OAuthTokenError: If response JSON is invalid
    """
    try:
        content = await response.aread()
        token_response = OAuthToken.model_validate_json(content)
        return token_response
    except ValidationError as e:  # pragma: no cover
        # Chain explicitly so the pydantic validation failure is preserved
        # as __cause__ (implicit __context__ is lost if callers re-raise).
        raise OAuthTokenError(f"Invalid token response: {e}") from e
|
||||
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Experimental client features.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
from mcp.client.experimental.tasks import ExperimentalClientFeatures
|
||||
|
||||
__all__ = ["ExperimentalClientFeatures"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
Experimental task handler protocols for server -> client requests.
|
||||
|
||||
This module provides Protocol types and default handlers for when servers
|
||||
send task-related requests to clients (the reverse of normal client -> server flow).
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Use cases:
|
||||
- Server sends task-augmented sampling/elicitation request to client
|
||||
- Client creates a local task, spawns background work, returns CreateTaskResult
|
||||
- Server polls client's task status via tasks/get, tasks/result, etc.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.session import RequestResponder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.client.session import ClientSession
|
||||
|
||||
|
||||
class GetTaskHandlerFnT(Protocol):
    """Handler for tasks/get requests from server.

    Receives the request context and the tasks/get parameters; returns
    either a GetTaskResult or an ErrorData to be sent back to the server.

    WARNING: This is experimental and may change without notice.
    """

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.GetTaskRequestParams,
    ) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class GetTaskResultHandlerFnT(Protocol):
    """Handler for tasks/result requests from server.

    Receives the request context and the tasks/result parameters; returns
    the task's payload result or an ErrorData.

    WARNING: This is experimental and may change without notice.
    """

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.GetTaskPayloadRequestParams,
    ) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class ListTasksHandlerFnT(Protocol):
    """Handler for tasks/list requests from server.

    Receives optional pagination parameters (None means first page); returns
    a ListTasksResult or an ErrorData.

    WARNING: This is experimental and may change without notice.
    """

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.PaginatedRequestParams | None,
    ) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class CancelTaskHandlerFnT(Protocol):
    """Handler for tasks/cancel requests from server.

    Receives the task to cancel; returns the updated task state as a
    CancelTaskResult, or an ErrorData.

    WARNING: This is experimental and may change without notice.
    """

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.CancelTaskRequestParams,
    ) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class TaskAugmentedSamplingFnT(Protocol):
    """Handler for task-augmented sampling/createMessage requests from server.

    When server sends a CreateMessageRequest with task field, this callback
    is invoked. The callback should create a task, spawn background work,
    and return CreateTaskResult immediately.

    WARNING: This is experimental and may change without notice.
    """

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.CreateMessageRequestParams,
        task_metadata: types.TaskMetadata,
    ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class TaskAugmentedElicitationFnT(Protocol):
    """Handler for task-augmented elicitation/create requests from server.

    When server sends an ElicitRequest with task field, this callback
    is invoked. The callback should create a task, spawn background work,
    and return CreateTaskResult immediately.

    WARNING: This is experimental and may change without notice.
    """

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.ElicitRequestParams,
        task_metadata: types.TaskMetadata,
    ) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
async def default_get_task_handler(
    context: RequestContext["ClientSession", Any],
    params: types.GetTaskRequestParams,
) -> types.GetTaskResult | types.ErrorData:
    """Default tasks/get handler: reports the method as unsupported."""
    unsupported = types.ErrorData(code=types.METHOD_NOT_FOUND, message="tasks/get not supported")
    return unsupported
|
||||
|
||||
|
||||
async def default_get_task_result_handler(
    context: RequestContext["ClientSession", Any],
    params: types.GetTaskPayloadRequestParams,
) -> types.GetTaskPayloadResult | types.ErrorData:
    """Default tasks/result handler: reports the method as unsupported."""
    unsupported = types.ErrorData(code=types.METHOD_NOT_FOUND, message="tasks/result not supported")
    return unsupported
|
||||
|
||||
|
||||
async def default_list_tasks_handler(
    context: RequestContext["ClientSession", Any],
    params: types.PaginatedRequestParams | None,
) -> types.ListTasksResult | types.ErrorData:
    """Default tasks/list handler: reports the method as unsupported."""
    unsupported = types.ErrorData(code=types.METHOD_NOT_FOUND, message="tasks/list not supported")
    return unsupported
|
||||
|
||||
|
||||
async def default_cancel_task_handler(
    context: RequestContext["ClientSession", Any],
    params: types.CancelTaskRequestParams,
) -> types.CancelTaskResult | types.ErrorData:
    """Default tasks/cancel handler: reports the method as unsupported."""
    unsupported = types.ErrorData(code=types.METHOD_NOT_FOUND, message="tasks/cancel not supported")
    return unsupported
|
||||
|
||||
|
||||
async def default_task_augmented_sampling(
    context: RequestContext["ClientSession", Any],
    params: types.CreateMessageRequestParams,
    task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData:
    """Default handler rejecting task-augmented sampling/createMessage."""
    rejection = types.ErrorData(code=types.INVALID_REQUEST, message="Task-augmented sampling not supported")
    return rejection
|
||||
|
||||
|
||||
async def default_task_augmented_elicitation(
    context: RequestContext["ClientSession", Any],
    params: types.ElicitRequestParams,
    task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData:
    """Default handler rejecting task-augmented elicitation/create."""
    rejection = types.ErrorData(code=types.INVALID_REQUEST, message="Task-augmented elicitation not supported")
    return rejection
|
||||
|
||||
|
||||
@dataclass
class ExperimentalTaskHandlers:
    """Container for experimental task handlers.

    Groups all task-related handlers that handle server -> client requests.
    This includes both pure task requests (get, list, cancel, result) and
    task-augmented request handlers (sampling, elicitation with task field).

    WARNING: These APIs are experimental and may change without notice.

    Example:
        handlers = ExperimentalTaskHandlers(
            get_task=my_get_task_handler,
            list_tasks=my_list_tasks_handler,
        )
        session = ClientSession(..., experimental_task_handlers=handlers)
    """

    # Pure task request handlers
    get_task: GetTaskHandlerFnT = field(default=default_get_task_handler)
    get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler)
    list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler)
    cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler)

    # Task-augmented request handlers
    augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling)
    augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation)

    def build_capability(self) -> types.ClientTasksCapability | None:
        """Build ClientTasksCapability from the configured handlers.

        Returns a capability object that reflects which handlers are configured
        (i.e., not using the default "not supported" handlers).

        Returns:
            ClientTasksCapability if any handlers are provided, None otherwise
        """
        # Identity comparison against the module-level defaults detects which
        # handlers the user actually supplied. NOTE(review): get_task and
        # get_task_result do not feed into the advertised capability -- confirm
        # this is intentional in the protocol's capability schema.
        has_list = self.list_tasks is not default_list_tasks_handler
        has_cancel = self.cancel_task is not default_cancel_task_handler
        has_sampling = self.augmented_sampling is not default_task_augmented_sampling
        has_elicitation = self.augmented_elicitation is not default_task_augmented_elicitation

        # If no handlers are provided, return None
        if not any([has_list, has_cancel, has_sampling, has_elicitation]):
            return None

        # Build requests capability if any request handlers are provided
        requests_capability: types.ClientTasksRequestsCapability | None = None
        if has_sampling or has_elicitation:
            requests_capability = types.ClientTasksRequestsCapability(
                sampling=types.TasksSamplingCapability(createMessage=types.TasksCreateMessageCapability())
                if has_sampling
                else None,
                elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability())
                if has_elicitation
                else None,
            )

        return types.ClientTasksCapability(
            list=types.TasksListCapability() if has_list else None,
            cancel=types.TasksCancelCapability() if has_cancel else None,
            requests=requests_capability,
        )

    @staticmethod
    def handles_request(request: types.ServerRequest) -> bool:
        """Check if this handler handles the given request type."""
        return isinstance(
            request.root,
            types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest,
        )

    async def handle_request(
        self,
        ctx: RequestContext["ClientSession", Any],
        responder: RequestResponder[types.ServerRequest, types.ClientResult],
    ) -> None:
        """Handle a task-related request from the server.

        Call handles_request() first to check if this handler can handle the request.
        """
        # Handlers may return either a ClientResult subtype or ErrorData; this
        # adapter validates whichever came back into the wire-level union.
        client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
            types.ClientResult | types.ErrorData
        )

        # Dispatch on the concrete request type, destructuring its params.
        match responder.request.root:
            case types.GetTaskRequest(params=params):
                response = await self.get_task(ctx, params)
                client_response = client_response_type.validate_python(response)
                await responder.respond(client_response)

            case types.GetTaskPayloadRequest(params=params):
                response = await self.get_task_result(ctx, params)
                client_response = client_response_type.validate_python(response)
                await responder.respond(client_response)

            case types.ListTasksRequest(params=params):
                response = await self.list_tasks(ctx, params)
                client_response = client_response_type.validate_python(response)
                await responder.respond(client_response)

            case types.CancelTaskRequest(params=params):
                response = await self.cancel_task(ctx, params)
                client_response = client_response_type.validate_python(response)
                await responder.respond(client_response)

            case _: # pragma: no cover
                # Unreachable when handles_request() was consulted first.
                raise ValueError(f"Unhandled request type: {type(responder.request.root)}")
|
||||
|
||||
|
||||
# Backwards compatibility aliases
# Kept so existing imports of the *_callback names keep resolving to the
# renamed default handlers above.
default_task_augmented_sampling_callback = default_task_augmented_sampling
default_task_augmented_elicitation_callback = default_task_augmented_elicitation
|
||||
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Experimental client-side task support.
|
||||
|
||||
This module provides client methods for interacting with MCP tasks.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Example:
|
||||
# Call a tool as a task
|
||||
result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"})
|
||||
task_id = result.task.taskId
|
||||
|
||||
# Get task status
|
||||
status = await session.experimental.get_task(task_id)
|
||||
|
||||
# Get task result when complete
|
||||
if status.status == "completed":
|
||||
result = await session.experimental.get_task_result(task_id, CallToolResult)
|
||||
|
||||
# List all tasks
|
||||
tasks = await session.experimental.list_tasks()
|
||||
|
||||
# Cancel a task
|
||||
await session.experimental.cancel_task(task_id)
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.experimental.tasks.polling import poll_until_terminal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.client.session import ClientSession
|
||||
|
||||
ResultT = TypeVar("ResultT", bound=types.Result)
|
||||
|
||||
|
||||
class ExperimentalClientFeatures:
    """
    Experimental client features for tasks and other experimental APIs.

    WARNING: These APIs are experimental and may change without notice.

    Access via session.experimental:
        status = await session.experimental.get_task(task_id)
    """

    def __init__(self, session: "ClientSession") -> None:
        # Thin facade: every method sends its request through this session.
        self._session = session

    async def call_tool_as_task(
        self,
        name: str,
        arguments: dict[str, Any] | None = None,
        *,
        ttl: int = 60000,
        meta: dict[str, Any] | None = None,
    ) -> types.CreateTaskResult:
        """Call a tool as a task, returning a CreateTaskResult for polling.

        This is a convenience method for calling tools that support task execution.
        The server will return a task reference instead of the immediate result,
        which can then be polled via `get_task()` and retrieved via `get_task_result()`.

        Args:
            name: The tool name
            arguments: Tool arguments
            ttl: Task time-to-live in milliseconds (default: 60000 = 1 minute)
            meta: Optional metadata to include in the request

        Returns:
            CreateTaskResult containing the task reference

        Example:
            # Create task
            result = await session.experimental.call_tool_as_task(
                "long_running_tool", {"input": "data"}
            )
            task_id = result.task.taskId

            # Poll for completion
            while True:
                status = await session.experimental.get_task(task_id)
                if status.status == "completed":
                    break
                await asyncio.sleep(0.5)

            # Get result
            final = await session.experimental.get_task_result(task_id, CallToolResult)
        """
        # Wrap the caller-supplied metadata dict in the protocol's Meta model.
        _meta: types.RequestParams.Meta | None = None
        if meta is not None:
            _meta = types.RequestParams.Meta(**meta)

        # The task=TaskMetadata(...) field is what opts this call into task
        # execution; the server answers with CreateTaskResult, not CallToolResult.
        return await self._session.send_request(
            types.ClientRequest(
                types.CallToolRequest(
                    params=types.CallToolRequestParams(
                        name=name,
                        arguments=arguments,
                        task=types.TaskMetadata(ttl=ttl),
                        _meta=_meta,
                    ),
                )
            ),
            types.CreateTaskResult,
        )

    async def get_task(self, task_id: str) -> types.GetTaskResult:
        """
        Get the current status of a task.

        Args:
            task_id: The task identifier

        Returns:
            GetTaskResult containing the task status and metadata
        """
        return await self._session.send_request(
            types.ClientRequest(
                types.GetTaskRequest(
                    params=types.GetTaskRequestParams(taskId=task_id),
                )
            ),
            types.GetTaskResult,
        )

    async def get_task_result(
        self,
        task_id: str,
        result_type: type[ResultT],
    ) -> ResultT:
        """
        Get the result of a completed task.

        The result type depends on the original request type:
        - tools/call tasks return CallToolResult
        - Other request types return their corresponding result type

        Args:
            task_id: The task identifier
            result_type: The expected result type (e.g., CallToolResult)

        Returns:
            The task result, validated against result_type
        """
        return await self._session.send_request(
            types.ClientRequest(
                types.GetTaskPayloadRequest(
                    params=types.GetTaskPayloadRequestParams(taskId=task_id),
                )
            ),
            result_type,
        )

    async def list_tasks(
        self,
        cursor: str | None = None,
    ) -> types.ListTasksResult:
        """
        List all tasks.

        Args:
            cursor: Optional pagination cursor

        Returns:
            ListTasksResult containing tasks and optional next cursor
        """
        # No cursor means first page; params stays None in that case.
        params = types.PaginatedRequestParams(cursor=cursor) if cursor else None
        return await self._session.send_request(
            types.ClientRequest(
                types.ListTasksRequest(params=params),
            ),
            types.ListTasksResult,
        )

    async def cancel_task(self, task_id: str) -> types.CancelTaskResult:
        """
        Cancel a running task.

        Args:
            task_id: The task identifier

        Returns:
            CancelTaskResult with the updated task state
        """
        return await self._session.send_request(
            types.ClientRequest(
                types.CancelTaskRequest(
                    params=types.CancelTaskRequestParams(taskId=task_id),
                )
            ),
            types.CancelTaskResult,
        )

    async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
        """
        Poll a task until it reaches a terminal status.

        Yields GetTaskResult for each poll, allowing the caller to react to
        status changes (e.g., handle input_required). Exits when task reaches
        a terminal status (completed, failed, cancelled).

        Respects the pollInterval hint from the server.

        Args:
            task_id: The task identifier

        Yields:
            GetTaskResult for each poll

        Example:
            async for status in session.experimental.poll_task(task_id):
                print(f"Status: {status.status}")
                if status.status == "input_required":
                    # Handle elicitation request via tasks/result
                    pass

            # Task is now terminal, get the result
            result = await session.experimental.get_task_result(task_id, CallToolResult)
        """
        # Delegate cadence (pollInterval handling, terminal detection) to the
        # shared helper; we only re-yield each observed status.
        async for status in poll_until_terminal(self.get_task, task_id):
            yield status
|
||||
615
.venv/lib/python3.11/site-packages/mcp/client/session.py
Normal file
615
.venv/lib/python3.11/site-packages/mcp/client/session.py
Normal file
@@ -0,0 +1,615 @@
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol, overload
|
||||
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.experimental import ExperimentalClientFeatures
|
||||
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
|
||||
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
|
||||
|
||||
logger = logging.getLogger("client")
|
||||
|
||||
|
||||
class SamplingFnT(Protocol):
    """Callback signature for handling sampling/createMessage requests from the server."""

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.CreateMessageRequestParams,
    ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class ElicitationFnT(Protocol):
    """Callback signature for handling elicitation/create requests from the server."""

    async def __call__(
        self,
        context: RequestContext["ClientSession", Any],
        params: types.ElicitRequestParams,
    ) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class ListRootsFnT(Protocol):
    """Callback signature for handling roots/list requests from the server."""

    async def __call__(
        self, context: RequestContext["ClientSession", Any]
    ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class LoggingFnT(Protocol):
    """Callback signature for receiving notifications/message log events."""

    async def __call__(
        self,
        params: types.LoggingMessageNotificationParams,
    ) -> None: ... # pragma: no branch
|
||||
|
||||
|
||||
class MessageHandlerFnT(Protocol):
    """Callback signature for observing every incoming server message.

    Receives server requests (as a RequestResponder), server notifications,
    or an Exception raised while reading from the transport.
    """

    async def __call__(
        self,
        message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
    ) -> None: ... # pragma: no branch
|
||||
|
||||
|
||||
async def _default_message_handler(
    message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
    """No-op message handler used when the caller supplies none.

    The checkpoint yields to the event loop so even ignored messages still
    provide a scheduling/cancellation point.
    """
    await anyio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
async def _default_sampling_callback(
    context: RequestContext["ClientSession", Any],
    params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData:
    """Fallback sampling handler: rejects sampling/createMessage requests."""
    rejection = types.ErrorData(code=types.INVALID_REQUEST, message="Sampling not supported")
    return rejection
|
||||
|
||||
|
||||
async def _default_elicitation_callback(
    context: RequestContext["ClientSession", Any],
    params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
    """Fallback elicitation handler: rejects elicitation/create requests."""
    rejection = types.ErrorData(code=types.INVALID_REQUEST, message="Elicitation not supported")  # pragma: no cover
    return rejection
|
||||
|
||||
|
||||
async def _default_list_roots_callback(
    context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
    """Fallback roots handler: rejects roots/list requests."""
    rejection = types.ErrorData(code=types.INVALID_REQUEST, message="List roots not supported")
    return rejection
|
||||
|
||||
|
||||
async def _default_logging_callback(
    params: types.LoggingMessageNotificationParams,
) -> None:
    """Default logging callback: silently drops server log notifications."""
    pass
|
||||
|
||||
|
||||
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
types.ClientRequest,
|
||||
types.ClientNotification,
|
||||
types.ClientResult,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
]
|
||||
):
|
||||
    def __init__(
        self,
        read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
        write_stream: MemoryObjectSendStream[SessionMessage],
        read_timeout_seconds: timedelta | None = None,
        sampling_callback: SamplingFnT | None = None,
        elicitation_callback: ElicitationFnT | None = None,
        list_roots_callback: ListRootsFnT | None = None,
        logging_callback: LoggingFnT | None = None,
        message_handler: MessageHandlerFnT | None = None,
        client_info: types.Implementation | None = None,
        *,
        sampling_capabilities: types.SamplingCapability | None = None,
        experimental_task_handlers: ExperimentalTaskHandlers | None = None,
    ) -> None:
        """Create a client session over a pair of message streams.

        Args:
            read_stream: Incoming messages (or read errors) from the transport.
            write_stream: Outgoing messages to the transport.
            read_timeout_seconds: Optional per-request read timeout.
            sampling_callback: Handler for server sampling requests; omitting it
                leaves the sampling capability unadvertised (see initialize()).
            elicitation_callback: Handler for server elicitation requests.
            list_roots_callback: Handler for server roots/list requests.
            logging_callback: Receiver for server log notifications.
            message_handler: Observer for every incoming message.
            client_info: Implementation info sent during initialization;
                defaults to DEFAULT_CLIENT_INFO.
            sampling_capabilities: Explicit sampling capability to advertise.
            experimental_task_handlers: Experimental task request handlers.
        """
        super().__init__(
            read_stream,
            write_stream,
            types.ServerRequest,
            types.ServerNotification,
            read_timeout_seconds=read_timeout_seconds,
        )
        # Each callback falls back to a module-level default; initialize()
        # compares against those same defaults to decide which capabilities
        # to advertise.
        self._client_info = client_info or DEFAULT_CLIENT_INFO
        self._sampling_callback = sampling_callback or _default_sampling_callback
        self._sampling_capabilities = sampling_capabilities
        self._elicitation_callback = elicitation_callback or _default_elicitation_callback
        self._list_roots_callback = list_roots_callback or _default_list_roots_callback
        self._logging_callback = logging_callback or _default_logging_callback
        self._message_handler = message_handler or _default_message_handler
        self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
        self._server_capabilities: types.ServerCapabilities | None = None
        self._experimental_features: ExperimentalClientFeatures | None = None

        # Experimental: Task handlers (use defaults if not provided)
        self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
|
||||
|
||||
    async def initialize(self) -> types.InitializeResult:
        """Perform the MCP initialization handshake.

        Advertises only the capabilities whose callbacks were actually
        supplied (detected by identity comparison against the module-level
        defaults), validates the negotiated protocol version, stores the
        server's capabilities, and sends the initialized notification.

        Returns:
            The server's InitializeResult.

        Raises:
            RuntimeError: If the server's protocol version is unsupported.
        """
        # Advertise sampling only when a real callback was provided.
        sampling = (
            (self._sampling_capabilities or types.SamplingCapability())
            if self._sampling_callback is not _default_sampling_callback
            else None
        )
        elicitation = (
            types.ElicitationCapability(
                form=types.FormElicitationCapability(),
                url=types.UrlElicitationCapability(),
            )
            if self._elicitation_callback is not _default_elicitation_callback
            else None
        )
        roots = (
            # TODO: Should this be based on whether we
            # _will_ send notifications, or only whether
            # they're supported?
            types.RootsCapability(listChanged=True)
            if self._list_roots_callback is not _default_list_roots_callback
            else None
        )

        result = await self.send_request(
            types.ClientRequest(
                types.InitializeRequest(
                    params=types.InitializeRequestParams(
                        protocolVersion=types.LATEST_PROTOCOL_VERSION,
                        capabilities=types.ClientCapabilities(
                            sampling=sampling,
                            elicitation=elicitation,
                            experimental=None,
                            roots=roots,
                            tasks=self._task_handlers.build_capability(),
                        ),
                        clientInfo=self._client_info,
                    ),
                )
            ),
            types.InitializeResult,
        )

        # Reject servers speaking a protocol version we cannot handle.
        if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
            raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")

        # Cache for get_server_capabilities().
        self._server_capabilities = result.capabilities

        # Handshake completes only after the initialized notification is sent.
        await self.send_notification(types.ClientNotification(types.InitializedNotification()))

        return result
|
||||
|
||||
    def get_server_capabilities(self) -> types.ServerCapabilities | None:
        """Return the server capabilities received during initialization.

        Returns None if the session has not been initialized yet.
        """
        # Populated by initialize() from the InitializeResult.
        return self._server_capabilities
|
||||
|
||||
@property
|
||||
def experimental(self) -> ExperimentalClientFeatures:
|
||||
"""Experimental APIs for tasks and other features.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Example:
|
||||
status = await session.experimental.get_task(task_id)
|
||||
result = await session.experimental.get_task_result(task_id, CallToolResult)
|
||||
"""
|
||||
if self._experimental_features is None:
|
||||
self._experimental_features = ExperimentalClientFeatures(self)
|
||||
return self._experimental_features
|
||||
|
||||
async def send_ping(self) -> types.EmptyResult:
|
||||
"""Send a ping request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(types.PingRequest()),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self,
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
await self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.ProgressNotification(
|
||||
params=types.ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
message=message,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
|
||||
"""Send a logging/setLevel request."""
|
||||
return await self.send_request( # pragma: no cover
|
||||
types.ClientRequest(
|
||||
types.SetLevelRequest(
|
||||
params=types.SetLevelRequestParams(level=level),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
@overload
|
||||
@deprecated("Use list_resources(params=PaginatedRequestParams(...)) instead")
|
||||
async def list_resources(self, cursor: str | None) -> types.ListResourcesResult: ...
|
||||
|
||||
@overload
|
||||
async def list_resources(self, *, params: types.PaginatedRequestParams | None) -> types.ListResourcesResult: ...
|
||||
|
||||
@overload
|
||||
async def list_resources(self) -> types.ListResourcesResult: ...
|
||||
|
||||
async def list_resources(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
*,
|
||||
params: types.PaginatedRequestParams | None = None,
|
||||
) -> types.ListResourcesResult:
|
||||
"""Send a resources/list request.
|
||||
|
||||
Args:
|
||||
cursor: Simple cursor string for pagination (deprecated, use params instead)
|
||||
params: Full pagination parameters including cursor and any future fields
|
||||
"""
|
||||
if params is not None and cursor is not None:
|
||||
raise ValueError("Cannot specify both cursor and params")
|
||||
|
||||
if params is not None:
|
||||
request_params = params
|
||||
elif cursor is not None:
|
||||
request_params = types.PaginatedRequestParams(cursor=cursor)
|
||||
else:
|
||||
request_params = None
|
||||
|
||||
return await self.send_request(
|
||||
types.ClientRequest(types.ListResourcesRequest(params=request_params)),
|
||||
types.ListResourcesResult,
|
||||
)
|
||||
|
||||
@overload
@deprecated("Use list_resource_templates(params=PaginatedRequestParams(...)) instead")
async def list_resource_templates(self, cursor: str | None) -> types.ListResourceTemplatesResult: ...

@overload
async def list_resource_templates(
    self, *, params: types.PaginatedRequestParams | None
) -> types.ListResourceTemplatesResult: ...

@overload
async def list_resource_templates(self) -> types.ListResourceTemplatesResult: ...

async def list_resource_templates(
    self,
    cursor: str | None = None,
    *,
    params: types.PaginatedRequestParams | None = None,
) -> types.ListResourceTemplatesResult:
    """Send a resources/templates/list request.

    Args:
        cursor: Simple cursor string for pagination (deprecated, use params instead)
        params: Full pagination parameters including cursor and any future fields
    """
    # Resolve the two mutually exclusive pagination spellings into one object.
    resolved = params
    if cursor is not None:
        if params is not None:
            raise ValueError("Cannot specify both cursor and params")
        resolved = types.PaginatedRequestParams(cursor=cursor)

    request = types.ClientRequest(types.ListResourceTemplatesRequest(params=resolved))
    return await self.send_request(request, types.ListResourceTemplatesResult)
|
||||
|
||||
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
    """Send a resources/read request for the given resource URI."""
    request = types.ReadResourceRequest(
        params=types.ReadResourceRequestParams(uri=uri),
    )
    return await self.send_request(types.ClientRequest(request), types.ReadResourceResult)
|
||||
|
||||
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
    """Send a resources/subscribe request for the given resource URI."""
    request = types.SubscribeRequest(
        params=types.SubscribeRequestParams(uri=uri),
    )
    return await self.send_request(types.ClientRequest(request), types.EmptyResult)  # pragma: no cover
|
||||
|
||||
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
    """Send a resources/unsubscribe request for the given resource URI."""
    request = types.UnsubscribeRequest(
        params=types.UnsubscribeRequestParams(uri=uri),
    )
    return await self.send_request(types.ClientRequest(request), types.EmptyResult)  # pragma: no cover
|
||||
|
||||
async def call_tool(
    self,
    name: str,
    arguments: dict[str, Any] | None = None,
    read_timeout_seconds: timedelta | None = None,
    progress_callback: ProgressFnT | None = None,
    *,
    meta: dict[str, Any] | None = None,
) -> types.CallToolResult:
    """Send a tools/call request with optional progress callback support.

    Args:
        name: Name of the tool to invoke.
        arguments: Optional tool arguments, forwarded unchanged.
        read_timeout_seconds: Per-request read timeout override.
        progress_callback: Invoked for progress updates tied to this request.
        meta: Optional metadata placed in the request params' ``_meta`` field.

    Returns:
        The tool call result. When the result is not an error, its structured
        content is validated against the tool's cached output schema.
    """

    # Convert the plain metadata dict into the typed model expected by params.
    _meta: types.RequestParams.Meta | None = None
    if meta is not None:
        _meta = types.RequestParams.Meta(**meta)

    result = await self.send_request(
        types.ClientRequest(
            types.CallToolRequest(
                params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta),
            )
        ),
        types.CallToolResult,
        request_read_timeout_seconds=read_timeout_seconds,
        progress_callback=progress_callback,
    )

    # Only successful results are schema-validated; error results pass through.
    if not result.isError:
        await self._validate_tool_result(name, result)

    return result
|
||||
|
||||
async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:
    """Validate the structured content of a tool result against its output schema.

    Args:
        name: Tool name, used to look up the cached output schema.
        result: The successful tool result to validate.

    Raises:
        RuntimeError: If the tool declares an output schema but returned no
            structured content, if the content fails validation, or if the
            schema itself is invalid.
    """
    if name not in self._tool_output_schemas:
        # refresh output schema cache
        await self.list_tools()

    output_schema = None
    if name in self._tool_output_schemas:
        output_schema = self._tool_output_schemas.get(name)
    else:
        # Tool still unknown after the refresh: skip validation, warn only.
        logger.warning(f"Tool {name} not listed by server, cannot validate any structured content")

    if output_schema is not None:
        from jsonschema import SchemaError, ValidationError, validate

        if result.structuredContent is None:
            raise RuntimeError(
                f"Tool {name} has an output schema but did not return structured content"
            )  # pragma: no cover
        try:
            validate(result.structuredContent, output_schema)
        except ValidationError as e:
            raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}")  # pragma: no cover
        except SchemaError as e:  # pragma: no cover
            raise RuntimeError(f"Invalid schema for tool {name}: {e}")  # pragma: no cover
|
||||
|
||||
@overload
@deprecated("Use list_prompts(params=PaginatedRequestParams(...)) instead")
async def list_prompts(self, cursor: str | None) -> types.ListPromptsResult: ...

@overload
async def list_prompts(self, *, params: types.PaginatedRequestParams | None) -> types.ListPromptsResult: ...

@overload
async def list_prompts(self) -> types.ListPromptsResult: ...

async def list_prompts(
    self,
    cursor: str | None = None,
    *,
    params: types.PaginatedRequestParams | None = None,
) -> types.ListPromptsResult:
    """Send a prompts/list request.

    Args:
        cursor: Simple cursor string for pagination (deprecated, use params instead)
        params: Full pagination parameters including cursor and any future fields
    """
    # Resolve the two mutually exclusive pagination spellings into one object.
    resolved = params
    if cursor is not None:
        if params is not None:
            raise ValueError("Cannot specify both cursor and params")
        resolved = types.PaginatedRequestParams(cursor=cursor)

    request = types.ClientRequest(types.ListPromptsRequest(params=resolved))
    return await self.send_request(request, types.ListPromptsResult)
|
||||
|
||||
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
    """Send a prompts/get request for the named prompt with optional arguments."""
    request = types.GetPromptRequest(
        params=types.GetPromptRequestParams(name=name, arguments=arguments),
    )
    return await self.send_request(types.ClientRequest(request), types.GetPromptResult)
|
||||
|
||||
async def complete(
    self,
    ref: types.ResourceTemplateReference | types.PromptReference,
    argument: dict[str, str],
    context_arguments: dict[str, str] | None = None,
) -> types.CompleteResult:
    """Send a completion/complete request.

    Args:
        ref: The prompt or resource-template reference being completed.
        argument: Mapping unpacked into a CompletionArgument.
        context_arguments: Optional previously-resolved argument context.
    """
    context = None if context_arguments is None else types.CompletionContext(arguments=context_arguments)

    request_params = types.CompleteRequestParams(
        ref=ref,
        argument=types.CompletionArgument(**argument),
        context=context,
    )
    return await self.send_request(
        types.ClientRequest(types.CompleteRequest(params=request_params)),
        types.CompleteResult,
    )
|
||||
|
||||
@overload
@deprecated("Use list_tools(params=PaginatedRequestParams(...)) instead")
async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ...

@overload
async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ...

@overload
async def list_tools(self) -> types.ListToolsResult: ...

async def list_tools(
    self,
    cursor: str | None = None,
    *,
    params: types.PaginatedRequestParams | None = None,
) -> types.ListToolsResult:
    """Send a tools/list request.

    Args:
        cursor: Simple cursor string for pagination (deprecated, use params instead)
        params: Full pagination parameters including cursor and any future fields
    """
    # Resolve the two mutually exclusive pagination spellings into one object.
    resolved = params
    if cursor is not None:
        if params is not None:
            raise ValueError("Cannot specify both cursor and params")
        resolved = types.PaginatedRequestParams(cursor=cursor)

    result = await self.send_request(
        types.ClientRequest(types.ListToolsRequest(params=resolved)),
        types.ListToolsResult,
    )

    # Cache tool output schemas for future validation.
    # Note: don't clear the cache, as we may be using a cursor.
    self._tool_output_schemas.update({tool.name: tool.outputSchema for tool in result.tools})

    return result
|
||||
|
||||
async def send_roots_list_changed(self) -> None:  # pragma: no cover
    """Send a roots/list_changed notification.

    Fire-and-forget: no response is awaited.
    """
    await self.send_notification(types.ClientNotification(types.RootsListChangedNotification()))
|
||||
|
||||
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
    """Dispatch a server-initiated request to the appropriate client callback.

    Builds a RequestContext for the request, then routes by request type:
    task-capable requests go to the experimental task handlers first; sampling,
    elicitation, roots listing, and ping are handled by the registered
    callbacks. Each branch responds through the responder context manager.
    """
    ctx = RequestContext[ClientSession, Any](
        request_id=responder.request_id,
        meta=responder.request_meta,
        session=self,
        # No lifespan state exists on the client side.
        lifespan_context=None,
    )

    # Delegate to experimental task handler if applicable
    if self._task_handlers.handles_request(responder.request):
        with responder:
            await self._task_handlers.handle_request(ctx, responder)
        return None

    # Core request handling
    match responder.request.root:
        case types.CreateMessageRequest(params=params):
            with responder:
                # Check if this is a task-augmented request
                if params.task is not None:
                    response = await self._task_handlers.augmented_sampling(ctx, params, params.task)
                else:
                    response = await self._sampling_callback(ctx, params)
                client_response = ClientResponse.validate_python(response)
                await responder.respond(client_response)

        case types.ElicitRequest(params=params):
            with responder:
                # Check if this is a task-augmented request
                if params.task is not None:
                    response = await self._task_handlers.augmented_elicitation(ctx, params, params.task)
                else:
                    response = await self._elicitation_callback(ctx, params)
                client_response = ClientResponse.validate_python(response)
                await responder.respond(client_response)

        case types.ListRootsRequest():
            with responder:
                response = await self._list_roots_callback(ctx)
                client_response = ClientResponse.validate_python(response)
                await responder.respond(client_response)

        case types.PingRequest():  # pragma: no cover
            with responder:
                # Ping requires no callback; answer with an empty result.
                return await responder.respond(types.ClientResult(root=types.EmptyResult()))

        case _:  # pragma: no cover
            pass  # Task requests handled above by _task_handlers

    return None
|
||||
|
||||
async def _handle_incoming(
    self,
    req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
    """Handle incoming messages by forwarding to the message handler.

    ``req`` may be a request responder, a server notification, or an
    exception raised by the transport; all three are passed through as-is.
    """
    await self._message_handler(req)
|
||||
|
||||
async def _received_notification(self, notification: types.ServerNotification) -> None:
    """Handle notifications from the server.

    Only logging-message notifications trigger a callback here; other
    notification types require no special handling in this method.
    """
    # Process specific notification types
    match notification.root:
        case types.LoggingMessageNotification(params=params):
            await self._logging_callback(params)
        case types.ElicitCompleteNotification(params=params):
            # Handle elicitation completion notification
            # Clients MAY use this to retry requests or update UI
            # The notification contains the elicitationId of the completed elicitation
            pass
        case _:
            pass
|
||||
447
.venv/lib/python3.11/site-packages/mcp/client/session_group.py
Normal file
447
.venv/lib/python3.11/site-packages/mcp/client/session_group.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""
|
||||
SessionGroup concurrently manages multiple MCP session connections.
|
||||
|
||||
Tools, resources, and prompts are aggregated across servers. Servers may
|
||||
be connected to or disconnected from at any point after initialization.
|
||||
|
||||
This abstraction can handle naming collisions using a custom user-provided
|
||||
hook.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, TypeAlias, overload
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self, deprecated
|
||||
|
||||
import mcp
|
||||
from mcp import types
|
||||
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import StdioServerParameters
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
from mcp.shared._httpx_utils import create_mcp_http_client
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.session import ProgressFnT
|
||||
|
||||
|
||||
class SseServerParameters(BaseModel):
    """Parameters for initializing a sse_client."""

    # The endpoint URL.
    url: str

    # Optional headers to include in requests.
    headers: dict[str, Any] | None = None

    # HTTP timeout (seconds) for regular operations.
    timeout: float = 5

    # Timeout (seconds) for SSE read operations.
    sse_read_timeout: float = 60 * 5
|
||||
|
||||
|
||||
class StreamableHttpParameters(BaseModel):
    """Parameters for initializing a streamable_http_client."""

    # The endpoint URL.
    url: str

    # Optional headers to include in requests.
    headers: dict[str, Any] | None = None

    # HTTP timeout for regular operations.
    timeout: timedelta = timedelta(seconds=30)

    # Timeout for SSE read operations.
    sse_read_timeout: timedelta = timedelta(seconds=60 * 5)

    # Close the client session when the transport closes.
    terminate_on_close: bool = True
|
||||
|
||||
|
||||
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
|
||||
|
||||
|
||||
# Use dataclass instead of pydantic BaseModel
# because pydantic BaseModel cannot handle Protocol fields.
@dataclass
class ClientSessionParameters:
    """Parameters for establishing a client session to an MCP server.

    Every field is forwarded verbatim to the ``mcp.ClientSession``
    constructor when the session is established.
    """

    read_timeout_seconds: timedelta | None = None
    sampling_callback: SamplingFnT | None = None
    elicitation_callback: ElicitationFnT | None = None
    list_roots_callback: ListRootsFnT | None = None
    logging_callback: LoggingFnT | None = None
    message_handler: MessageHandlerFnT | None = None
    client_info: types.Implementation | None = None
|
||||
|
||||
|
||||
class ClientSessionGroup:
    """Client for managing connections to multiple MCP servers.

    This class is responsible for encapsulating management of server connections.
    It aggregates tools, resources, and prompts from all connected servers.

    For auxiliary handlers, such as resource subscription, this is delegated to
    the client and can be accessed via the session.

    Example Usage:
        name_fn = lambda name, server_info: f"{(server_info.name)}_{name}"
        async with ClientSessionGroup(component_name_hook=name_fn) as group:
            for server_param in server_params:
                await group.connect_to_server(server_param)
            ...

    """

    class _ComponentNames(BaseModel):
        """Used for reverse index to find components."""

        prompts: set[str] = set()
        resources: set[str] = set()
        tools: set[str] = set()

    # Standard MCP components, keyed by their (possibly hook-renamed) names.
    _prompts: dict[str, types.Prompt]
    _resources: dict[str, types.Resource]
    _tools: dict[str, types.Tool]

    # Client-server connection management.
    _sessions: dict[mcp.ClientSession, _ComponentNames]
    _tool_to_session: dict[str, mcp.ClientSession]
    _exit_stack: contextlib.AsyncExitStack
    _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]

    # Optional fn consuming (component_name, serverInfo) for custom names.
    # This provides a means to mitigate naming conflicts across servers.
    # Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}"
    _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str]
    _component_name_hook: _ComponentNameHook | None

    def __init__(
        self,
        exit_stack: contextlib.AsyncExitStack | None = None,
        component_name_hook: _ComponentNameHook | None = None,
    ) -> None:
        """Initializes the MCP client.

        Args:
            exit_stack: Optional externally-owned exit stack. When omitted,
                the group creates and later closes its own.
            component_name_hook: Optional callable mapping (name, server_info)
                to the name under which a component is registered.
        """

        self._tools = {}
        self._resources = {}
        self._prompts = {}

        self._sessions = {}
        self._tool_to_session = {}
        # Track ownership so __aenter__/__aexit__ only manage a stack we created.
        if exit_stack is None:
            self._exit_stack = contextlib.AsyncExitStack()
            self._owns_exit_stack = True
        else:
            self._exit_stack = exit_stack
            self._owns_exit_stack = False
        self._session_exit_stacks = {}
        self._component_name_hook = component_name_hook

    async def __aenter__(self) -> Self:  # pragma: no cover
        # Enter the exit stack only if we created it ourselves
        if self._owns_exit_stack:
            await self._exit_stack.__aenter__()
        return self

    async def __aexit__(
        self,
        _exc_type: type[BaseException] | None,
        _exc_val: BaseException | None,
        _exc_tb: TracebackType | None,
    ) -> bool | None:  # pragma: no cover
        """Closes session exit stacks and main exit stack upon completion."""

        # Only close the main exit stack if we created it
        if self._owns_exit_stack:
            await self._exit_stack.aclose()

        # Concurrently close session stacks.
        async with anyio.create_task_group() as tg:
            for exit_stack in self._session_exit_stacks.values():
                tg.start_soon(exit_stack.aclose)

    @property
    def sessions(self) -> list[mcp.ClientSession]:
        """Returns the list of sessions being managed."""
        return list(self._sessions.keys())  # pragma: no cover

    @property
    def prompts(self) -> dict[str, types.Prompt]:
        """Returns the prompts as a dictionary of names to prompts."""
        return self._prompts

    @property
    def resources(self) -> dict[str, types.Resource]:
        """Returns the resources as a dictionary of names to resources."""
        return self._resources

    @property
    def tools(self) -> dict[str, types.Tool]:
        """Returns the tools as a dictionary of names to tools."""
        return self._tools

    @overload
    async def call_tool(
        self,
        name: str,
        arguments: dict[str, Any],
        read_timeout_seconds: timedelta | None = None,
        progress_callback: ProgressFnT | None = None,
        *,
        meta: dict[str, Any] | None = None,
    ) -> types.CallToolResult: ...

    @overload
    @deprecated("The 'args' parameter is deprecated. Use 'arguments' instead.")
    async def call_tool(
        self,
        name: str,
        *,
        args: dict[str, Any],
        read_timeout_seconds: timedelta | None = None,
        progress_callback: ProgressFnT | None = None,
        meta: dict[str, Any] | None = None,
    ) -> types.CallToolResult: ...

    async def call_tool(
        self,
        name: str,
        arguments: dict[str, Any] | None = None,
        read_timeout_seconds: timedelta | None = None,
        progress_callback: ProgressFnT | None = None,
        *,
        meta: dict[str, Any] | None = None,
        args: dict[str, Any] | None = None,
    ) -> types.CallToolResult:
        """Executes a tool given its name and arguments.

        ``name`` is the group-level (possibly hook-renamed) name; the call is
        routed to the owning session using the tool's server-side name.
        """
        session = self._tool_to_session[name]
        session_tool_name = self.tools[name].name
        return await session.call_tool(
            session_tool_name,
            # The deprecated 'args' keyword takes precedence when supplied.
            arguments if args is None else args,
            read_timeout_seconds=read_timeout_seconds,
            progress_callback=progress_callback,
            meta=meta,
        )

    async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
        """Disconnects from a single MCP server.

        Removes the session's aggregated components and closes its exit stack.

        Raises:
            McpError: If the session is not managed by this group.
        """

        session_known_for_components = session in self._sessions
        session_known_for_stack = session in self._session_exit_stacks

        if not session_known_for_components and not session_known_for_stack:
            raise McpError(
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message="Provided session is not managed or already disconnected.",
                )
            )

        if session_known_for_components:  # pragma: no cover
            component_names = self._sessions.pop(session)  # Pop from _sessions tracking

            # Remove prompts associated with the session.
            for name in component_names.prompts:
                if name in self._prompts:
                    del self._prompts[name]
            # Remove resources associated with the session.
            for name in component_names.resources:
                if name in self._resources:
                    del self._resources[name]
            # Remove tools associated with the session.
            for name in component_names.tools:
                if name in self._tools:
                    del self._tools[name]
                if name in self._tool_to_session:
                    del self._tool_to_session[name]

        # Clean up the session's resources via its dedicated exit stack
        if session_known_for_stack:
            session_stack_to_close = self._session_exit_stacks.pop(session)  # pragma: no cover
            await session_stack_to_close.aclose()  # pragma: no cover

    async def connect_with_session(
        self, server_info: types.Implementation, session: mcp.ClientSession
    ) -> mcp.ClientSession:
        """Aggregates components from an already-established session into the group."""
        await self._aggregate_components(server_info, session)
        return session

    async def connect_to_server(
        self,
        server_params: ServerParameters,
        session_params: ClientSessionParameters | None = None,
    ) -> mcp.ClientSession:
        """Connects to a single MCP server and aggregates its components."""
        server_info, session = await self._establish_session(server_params, session_params or ClientSessionParameters())
        return await self.connect_with_session(server_info, session)

    async def _establish_session(
        self,
        server_params: ServerParameters,
        session_params: ClientSessionParameters,
    ) -> tuple[types.Implementation, mcp.ClientSession]:
        """Establish a client session to an MCP server.

        Selects the transport from the parameter type (stdio, SSE, or
        streamable HTTP), initializes the session, and registers the session's
        exit stack with the group's main exit stack.
        """

        session_stack = contextlib.AsyncExitStack()
        try:
            # Create read and write streams that facilitate io with the server.
            if isinstance(server_params, StdioServerParameters):
                client = mcp.stdio_client(server_params)
                read, write = await session_stack.enter_async_context(client)
            elif isinstance(server_params, SseServerParameters):
                client = sse_client(
                    url=server_params.url,
                    headers=server_params.headers,
                    timeout=server_params.timeout,
                    sse_read_timeout=server_params.sse_read_timeout,
                )
                read, write = await session_stack.enter_async_context(client)
            else:
                httpx_client = create_mcp_http_client(
                    headers=server_params.headers,
                    timeout=httpx.Timeout(
                        server_params.timeout.total_seconds(),
                        read=server_params.sse_read_timeout.total_seconds(),
                    ),
                )
                await session_stack.enter_async_context(httpx_client)

                client = streamable_http_client(
                    url=server_params.url,
                    http_client=httpx_client,
                    terminate_on_close=server_params.terminate_on_close,
                )
                # Streamable HTTP also yields a third value, unused here.
                read, write, _ = await session_stack.enter_async_context(client)

            session = await session_stack.enter_async_context(
                mcp.ClientSession(
                    read,
                    write,
                    read_timeout_seconds=session_params.read_timeout_seconds,
                    sampling_callback=session_params.sampling_callback,
                    elicitation_callback=session_params.elicitation_callback,
                    list_roots_callback=session_params.list_roots_callback,
                    logging_callback=session_params.logging_callback,
                    message_handler=session_params.message_handler,
                    client_info=session_params.client_info,
                )
            )

            result = await session.initialize()

            # Session successfully initialized.
            # Store its stack and register the stack with the main group stack.
            self._session_exit_stacks[session] = session_stack
            # session_stack itself becomes a resource managed by the
            # main _exit_stack.
            await self._exit_stack.enter_async_context(session_stack)

            return result.serverInfo, session
        except Exception:  # pragma: no cover
            # If anything during this setup fails, ensure the session-specific
            # stack is closed.
            await session_stack.aclose()
            raise

    async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None:
        """Aggregates prompts, resources, and tools from a given session.

        Raises:
            McpError: If any fetched component name collides with one already
                registered in the group.
        """

        # Create a reverse index so we can find all prompts, resources, and
        # tools belonging to this session. Used for removing components from
        # the session group via self.disconnect_from_server.
        component_names = self._ComponentNames()

        # Temporary components dicts. We do not want to modify the aggregate
        # lists in case of an intermediate failure.
        prompts_temp: dict[str, types.Prompt] = {}
        resources_temp: dict[str, types.Resource] = {}
        tools_temp: dict[str, types.Tool] = {}
        tool_to_session_temp: dict[str, mcp.ClientSession] = {}

        # Query the server for its prompts and aggregate to list.
        try:
            prompts = (await session.list_prompts()).prompts
            for prompt in prompts:
                name = self._component_name(prompt.name, server_info)
                prompts_temp[name] = prompt
                component_names.prompts.add(name)
        except McpError as err:  # pragma: no cover
            logging.warning(f"Could not fetch prompts: {err}")

        # Query the server for its resources and aggregate to list.
        try:
            resources = (await session.list_resources()).resources
            for resource in resources:
                name = self._component_name(resource.name, server_info)
                resources_temp[name] = resource
                component_names.resources.add(name)
        except McpError as err:  # pragma: no cover
            logging.warning(f"Could not fetch resources: {err}")

        # Query the server for its tools and aggregate to list.
        try:
            tools = (await session.list_tools()).tools
            for tool in tools:
                name = self._component_name(tool.name, server_info)
                tools_temp[name] = tool
                tool_to_session_temp[name] = session
                component_names.tools.add(name)
        except McpError as err:  # pragma: no cover
            logging.warning(f"Could not fetch tools: {err}")

        # Clean up exit stack for session if we couldn't retrieve anything
        # from the server.
        if not any((prompts_temp, resources_temp, tools_temp)):
            del self._session_exit_stacks[session]  # pragma: no cover

        # Check for duplicates.
        matching_prompts = prompts_temp.keys() & self._prompts.keys()
        if matching_prompts:
            raise McpError(  # pragma: no cover
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message=f"{matching_prompts} already exist in group prompts.",
                )
            )
        matching_resources = resources_temp.keys() & self._resources.keys()
        if matching_resources:
            raise McpError(  # pragma: no cover
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message=f"{matching_resources} already exist in group resources.",
                )
            )
        matching_tools = tools_temp.keys() & self._tools.keys()
        if matching_tools:
            raise McpError(
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message=f"{matching_tools} already exist in group tools.",
                )
            )

        # Aggregate components.
        self._sessions[session] = component_names
        self._prompts.update(prompts_temp)
        self._resources.update(resources_temp)
        self._tools.update(tools_temp)
        self._tool_to_session.update(tool_to_session_temp)

    def _component_name(self, name: str, server_info: types.Implementation) -> str:
        # Apply the user-provided renaming hook when configured; otherwise
        # keep the server-side name unchanged.
        if self._component_name_hook:
            return self._component_name_hook(name, server_info)
        return name
|
||||
164
.venv/lib/python3.11/site-packages/mcp/client/sse.py
Normal file
164
.venv/lib/python3.11/site-packages/mcp/client/sse.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urljoin, urlparse
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import TaskStatus
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import aconnect_sse
|
||||
from httpx_sse._exceptions import SSEError
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
    """Return *url* stripped of its query string and fragment."""
    path_only = urlparse(url).path
    return urljoin(url, path_only)
|
||||
|
||||
|
||||
def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
|
||||
query_params = parse_qs(urlparse(endpoint_url).query)
|
||||
return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]
|
||||
|
||||
|
||||
@asynccontextmanager
async def sse_client(
    url: str,
    headers: dict[str, Any] | None = None,
    timeout: float = 5,
    sse_read_timeout: float = 60 * 5,
    httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
    auth: httpx.Auth | None = None,
    on_session_created: Callable[[str], None] | None = None,
):
    """
    Client transport for SSE.

    Opens a long-lived GET request to `url` and yields a pair of in-memory
    streams: a read stream carrying parsed server messages (or parse errors)
    and a write stream whose messages are POSTed to the endpoint the server
    announces in its first "endpoint" event.

    `sse_read_timeout` determines how long (in seconds) the client will wait for a new
    event before disconnecting. All other HTTP operations are controlled by `timeout`.

    Args:
        url: The SSE endpoint URL.
        headers: Optional headers to include in requests.
        timeout: HTTP timeout for regular operations.
        sse_read_timeout: Timeout for SSE read operations.
        auth: Optional HTTPX authentication handler.
        on_session_created: Optional callback invoked with the session ID when received.
    """
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

    write_stream: MemoryObjectSendStream[SessionMessage]
    write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

    # Zero-capacity streams: producers block until a consumer is ready,
    # providing natural backpressure between transport and session.
    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    async with anyio.create_task_group() as tg:
        try:
            logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
            async with httpx_client_factory(
                headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
            ) as client:
                async with aconnect_sse(
                    client,
                    "GET",
                    url,
                ) as event_source:
                    event_source.response.raise_for_status()
                    logger.debug("SSE connection established")

                    # Consumes the SSE stream. Signals readiness to the task
                    # group (via task_status) once the server announces the
                    # POST endpoint.
                    async def sse_reader(
                        task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
                    ):
                        try:
                            async for sse in event_source.aiter_sse():  # pragma: no branch
                                logger.debug(f"Received SSE event: {sse.event}")
                                match sse.event:
                                    case "endpoint":
                                        endpoint_url = urljoin(url, sse.data)
                                        logger.debug(f"Received endpoint URL: {endpoint_url}")

                                        # Reject endpoints on a different origin: a
                                        # malicious/buggy server must not redirect
                                        # our POSTs elsewhere.
                                        url_parsed = urlparse(url)
                                        endpoint_parsed = urlparse(endpoint_url)
                                        if (  # pragma: no cover
                                            url_parsed.netloc != endpoint_parsed.netloc
                                            or url_parsed.scheme != endpoint_parsed.scheme
                                        ):
                                            error_msg = (  # pragma: no cover
                                                f"Endpoint origin does not match connection origin: {endpoint_url}"
                                            )
                                            logger.error(error_msg)  # pragma: no cover
                                            raise ValueError(error_msg)  # pragma: no cover

                                        if on_session_created:
                                            session_id = _extract_session_id_from_endpoint(endpoint_url)
                                            if session_id:
                                                on_session_created(session_id)

                                        task_status.started(endpoint_url)

                                    case "message":
                                        # Skip empty data (keep-alive pings)
                                        if not sse.data:
                                            continue
                                        try:
                                            message = types.JSONRPCMessage.model_validate_json(  # noqa: E501
                                                sse.data
                                            )
                                            logger.debug(f"Received server message: {message}")
                                        except Exception as exc:  # pragma: no cover
                                            # Forward parse failures to the consumer
                                            # instead of killing the reader task.
                                            logger.exception("Error parsing server message")  # pragma: no cover
                                            await read_stream_writer.send(exc)  # pragma: no cover
                                            continue  # pragma: no cover

                                        session_message = SessionMessage(message)
                                        await read_stream_writer.send(session_message)
                                    case _:  # pragma: no cover
                                        logger.warning(f"Unknown SSE event: {sse.event}")  # pragma: no cover
                        except SSEError as sse_exc:  # pragma: no cover
                            logger.exception("Encountered SSE exception")  # pragma: no cover
                            raise sse_exc  # pragma: no cover
                        except Exception as exc:  # pragma: no cover
                            logger.exception("Error in sse_reader")  # pragma: no cover
                            await read_stream_writer.send(exc)  # pragma: no cover
                        finally:
                            # Closing the writer signals EOF to the consumer.
                            await read_stream_writer.aclose()

                    # Drains the session's outgoing queue, POSTing each
                    # message to the server-announced endpoint.
                    async def post_writer(endpoint_url: str):
                        try:
                            async with write_stream_reader:
                                async for session_message in write_stream_reader:
                                    logger.debug(f"Sending client message: {session_message}")
                                    response = await client.post(
                                        endpoint_url,
                                        json=session_message.message.model_dump(
                                            by_alias=True,
                                            mode="json",
                                            exclude_none=True,
                                        ),
                                    )
                                    response.raise_for_status()
                                    logger.debug(f"Client message sent successfully: {response.status_code}")
                        except Exception:  # pragma: no cover
                            logger.exception("Error in post_writer")  # pragma: no cover
                        finally:
                            await write_stream.aclose()

                    # tg.start() blocks until sse_reader reports the endpoint
                    # URL via task_status.started().
                    endpoint_url = await tg.start(sse_reader)
                    logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
                    tg.start_soon(post_writer, endpoint_url)

                    try:
                        yield read_stream, write_stream
                    finally:
                        # Tear down reader/writer tasks when the caller exits.
                        tg.cancel_scope.cancel()
        finally:
            await read_stream_writer.aclose()
            await write_stream.aclose()
|
||||
278
.venv/lib/python3.11/site-packages/mcp/client/stdio/__init__.py
Normal file
278
.venv/lib/python3.11/site-packages/mcp/client/stdio/__init__.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Literal, TextIO
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.abc import Process
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.os.posix.utilities import terminate_posix_process_tree
|
||||
from mcp.os.win32.utilities import (
|
||||
FallbackProcess,
|
||||
create_windows_process,
|
||||
get_windows_executable_command,
|
||||
terminate_windows_process_tree,
|
||||
)
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)

# Environment variables to inherit by default
# Child MCP servers get a minimal, allow-listed environment rather than the
# full parent environment; the list differs per platform.
DEFAULT_INHERITED_ENV_VARS = (
    [
        "APPDATA",
        "HOMEDRIVE",
        "HOMEPATH",
        "LOCALAPPDATA",
        "PATH",
        "PATHEXT",
        "PROCESSOR_ARCHITECTURE",
        "SYSTEMDRIVE",
        "SYSTEMROOT",
        "TEMP",
        "USERNAME",
        "USERPROFILE",
    ]
    if sys.platform == "win32"
    else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
)

# Timeout for process termination before falling back to force kill
PROCESS_TERMINATION_TIMEOUT = 2.0
||||
|
||||
|
||||
def get_default_environment() -> dict[str, str]:
    """
    Returns a default environment object including only environment variables deemed
    safe to inherit.
    """
    # Values beginning with "()" are exported shell functions; inheriting
    # them would pass executable shell code to the child, so they are
    # filtered out along with unset variables.
    return {
        name: value
        for name in DEFAULT_INHERITED_ENV_VARS
        if (value := os.environ.get(name)) is not None and not value.startswith("()")
    }
|
||||
|
||||
|
||||
class StdioServerParameters(BaseModel):
    """Configuration for spawning an MCP server subprocess over stdio."""

    command: str
    """The executable to run to start the server."""

    args: list[str] = Field(default_factory=list)
    """Command line arguments to pass to the executable."""

    env: dict[str, str] | None = None
    """
    The environment to use when spawning the process.

    If not specified, the result of get_default_environment() will be used.
    """

    cwd: str | Path | None = None
    """The working directory to use when spawning the process."""

    encoding: str = "utf-8"
    """
    The text encoding used when sending/receiving messages to the server

    defaults to utf-8
    """

    encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict"
    """
    The text encoding error handler.

    See https://docs.python.org/3/library/codecs.html#codec-base-classes for
    explanations of possible values
    """
|
||||
|
||||
|
||||
@asynccontextmanager
async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr):
    """
    Client transport for stdio: this will connect to a server by spawning a
    process and communicating with it over stdin/stdout.

    Yields a (read_stream, write_stream) pair of in-memory streams carrying
    newline-delimited JSON-RPC messages. On exit, performs the MCP stdio
    shutdown sequence (close stdin -> wait -> terminate -> kill).
    """
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

    write_stream: MemoryObjectSendStream[SessionMessage]
    write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

    # Zero-capacity streams give backpressure between session and subprocess.
    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    try:
        command = _get_executable_command(server.command)

        # Open process with stderr piped for capture
        process = await _create_platform_compatible_process(
            command=command,
            args=server.args,
            # Caller-provided env entries override the safe defaults.
            env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
            errlog=errlog,
            cwd=server.cwd,
        )
    except OSError:
        # Clean up streams if process creation fails
        await read_stream.aclose()
        await write_stream.aclose()
        await read_stream_writer.aclose()
        await write_stream_reader.aclose()
        raise

    # Reads the child's stdout, splitting on newlines (messages may span
    # chunks, so a partial-line buffer is carried between reads).
    async def stdout_reader():
        assert process.stdout, "Opened process is missing stdout"

        try:
            async with read_stream_writer:
                buffer = ""
                async for chunk in TextReceiveStream(
                    process.stdout,
                    encoding=server.encoding,
                    errors=server.encoding_error_handler,
                ):
                    lines = (buffer + chunk).split("\n")
                    # The final element is an incomplete line; keep it for
                    # the next chunk.
                    buffer = lines.pop()

                    for line in lines:
                        try:
                            message = types.JSONRPCMessage.model_validate_json(line)
                        except Exception as exc:  # pragma: no cover
                            # Forward parse errors to the consumer rather
                            # than killing the reader task.
                            logger.exception("Failed to parse JSONRPC message from server")
                            await read_stream_writer.send(exc)
                            continue

                        session_message = SessionMessage(message)
                        await read_stream_writer.send(session_message)
        except anyio.ClosedResourceError:  # pragma: no cover
            await anyio.lowlevel.checkpoint()

    # Serializes outgoing messages as newline-delimited JSON on the child's
    # stdin.
    async def stdin_writer():
        assert process.stdin, "Opened process is missing stdin"

        try:
            async with write_stream_reader:
                async for session_message in write_stream_reader:
                    json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
                    await process.stdin.send(
                        (json + "\n").encode(
                            encoding=server.encoding,
                            errors=server.encoding_error_handler,
                        )
                    )
        except anyio.ClosedResourceError:  # pragma: no cover
            await anyio.lowlevel.checkpoint()

    async with (
        anyio.create_task_group() as tg,
        process,
    ):
        tg.start_soon(stdout_reader)
        tg.start_soon(stdin_writer)
        try:
            yield read_stream, write_stream
        finally:
            # MCP spec: stdio shutdown sequence
            # 1. Close input stream to server
            # 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time
            # 3. Send SIGKILL if still not exited
            if process.stdin:  # pragma: no branch
                try:
                    await process.stdin.aclose()
                except Exception:  # pragma: no cover
                    # stdin might already be closed, which is fine
                    pass

            try:
                # Give the process time to exit gracefully after stdin closes
                with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT):
                    await process.wait()
            except TimeoutError:
                # Process didn't exit from stdin closure, use platform-specific termination
                # which handles SIGTERM -> SIGKILL escalation
                await _terminate_process_tree(process)
            except ProcessLookupError:  # pragma: no cover
                # Process already exited, which is fine
                pass
            await read_stream.aclose()
            await write_stream.aclose()
            await read_stream_writer.aclose()
            await write_stream_reader.aclose()
|
||||
|
||||
|
||||
def _get_executable_command(command: str) -> str:
|
||||
"""
|
||||
Get the correct executable command normalized for the current platform.
|
||||
|
||||
Args:
|
||||
command: Base command (e.g., 'uvx', 'npx')
|
||||
|
||||
Returns:
|
||||
str: Platform-appropriate command
|
||||
"""
|
||||
if sys.platform == "win32": # pragma: no cover
|
||||
return get_windows_executable_command(command)
|
||||
else:
|
||||
return command # pragma: no cover
|
||||
|
||||
|
||||
async def _create_platform_compatible_process(
    command: str,
    args: list[str],
    env: dict[str, str] | None = None,
    errlog: TextIO = sys.stderr,
    cwd: Path | str | None = None,
):
    """
    Creates a subprocess in a platform-compatible way.

    Unix: Creates process in a new session/process group for killpg support
    Windows: Creates process in a Job Object for reliable child termination

    Args:
        command: Executable to spawn (already platform-normalized).
        args: Command-line arguments for the executable.
        env: Environment for the child; None inherits the parent's.
        errlog: Destination for the child's stderr output.
        cwd: Working directory for the child, if any.
    """
    if sys.platform == "win32":  # pragma: no cover
        process = await create_windows_process(command, args, env, errlog, cwd)
    else:
        # start_new_session=True detaches the child into its own process
        # group so the whole tree can later be signalled at once.
        process = await anyio.open_process(
            [command, *args],
            env=env,
            stderr=errlog,
            cwd=cwd,
            start_new_session=True,
        )  # pragma: no cover

    return process
|
||||
|
||||
|
||||
async def _terminate_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None:
    """
    Terminate a process and all its children using platform-specific methods.

    Unix: Uses os.killpg() for atomic process group termination
    Windows: Uses Job Objects via pywin32 for reliable child process cleanup

    Args:
        process: The process to terminate
        timeout_seconds: Timeout in seconds before force killing (default: 2.0)
    """
    if sys.platform == "win32":  # pragma: no cover
        await terminate_windows_process_tree(process, timeout_seconds)
    else:  # pragma: no cover
        # FallbackProcess should only be used for Windows compatibility
        assert isinstance(process, Process)
        await terminate_posix_process_tree(process, timeout_seconds)
|
||||
Binary file not shown.
722
.venv/lib/python3.11/site-packages/mcp/client/streamable_http.py
Normal file
722
.venv/lib/python3.11/site-packages/mcp/client/streamable_http.py
Normal file
@@ -0,0 +1,722 @@
|
||||
"""
|
||||
StreamableHTTP Client Transport Module
|
||||
|
||||
This module implements the StreamableHTTP transport for MCP clients,
|
||||
providing support for HTTP POST requests with optional SSE streaming responses
|
||||
and session management.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Any, overload
|
||||
from warnings import warn
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import TaskGroup
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from mcp.shared._httpx_utils import (
|
||||
McpHttpClientFactory,
|
||||
create_mcp_http_client,
|
||||
)
|
||||
from mcp.shared.message import ClientMessageMetadata, SessionMessage
|
||||
from mcp.types import (
|
||||
ErrorData,
|
||||
InitializeResult,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestId,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)


# Type aliases for the in-memory streams connecting transport and session.
SessionMessageOrError = SessionMessage | Exception
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
StreamReader = MemoryObjectReceiveStream[SessionMessage]
GetSessionIdCallback = Callable[[], str | None]

# HTTP header names used by the StreamableHTTP protocol.
MCP_SESSION_ID = "mcp-session-id"
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
LAST_EVENT_ID = "last-event-id"

# Reconnection defaults
DEFAULT_RECONNECTION_DELAY_MS = 1000  # 1 second fallback when server doesn't provide retry
MAX_RECONNECTION_ATTEMPTS = 2  # Max retry attempts before giving up
CONTENT_TYPE = "content-type"
ACCEPT = "accept"


# Content types the transport understands in responses.
JSON = "application/json"
SSE = "text/event-stream"

# Sentinel value for detecting unset optional parameters
_UNSET = object()
|
||||
|
||||
|
||||
class StreamableHTTPError(Exception):
    """Base exception for StreamableHTTP transport errors."""
|
||||
|
||||
|
||||
class ResumptionError(StreamableHTTPError):
    """Raised when resumption request is invalid (e.g. missing resumption token)."""
|
||||
|
||||
|
||||
@dataclass
class RequestContext:
    """Context for a request operation.

    Bundles everything a single POST/resumption operation needs so it can be
    handled by a standalone coroutine.
    """

    client: httpx.AsyncClient  # shared HTTP client for this transport
    session_id: str | None  # current MCP session ID, if negotiated
    session_message: SessionMessage  # outgoing message being sent
    metadata: ClientMessageMetadata | None  # resumption token/callback, if any
    read_stream_writer: StreamWriter  # where responses/errors are delivered
    headers: dict[str, str] | None = None  # Deprecated - no longer used
    sse_read_timeout: float | None = None  # Deprecated - no longer used
|
||||
|
||||
|
||||
class StreamableHTTPTransport:
|
||||
"""StreamableHTTP client transport implementation."""
|
||||
|
||||
    @overload
    def __init__(self, url: str) -> None: ...

    @overload
    @deprecated(
        "Parameters headers, timeout, sse_read_timeout, and auth are deprecated. "
        "Configure these on the httpx.AsyncClient instead."
    )
    def __init__(
        self,
        url: str,
        headers: dict[str, str] | None = None,
        timeout: float | timedelta = 30,
        sse_read_timeout: float | timedelta = 60 * 5,
        auth: httpx.Auth | None = None,
    ) -> None: ...

    def __init__(
        self,
        url: str,
        # _UNSET sentinel distinguishes "argument not passed" from an
        # explicit None, so the deprecation warning only fires when the
        # caller actually supplied the parameter.
        headers: Any = _UNSET,
        timeout: Any = _UNSET,
        sse_read_timeout: Any = _UNSET,
        auth: Any = _UNSET,
    ) -> None:
        """Initialize the StreamableHTTP transport.

        Args:
            url: The endpoint URL.
            headers: Optional headers to include in requests.
            timeout: HTTP timeout for regular operations.
            sse_read_timeout: Timeout for SSE read operations.
            auth: Optional HTTPX authentication handler.
        """
        # Check for deprecated parameters and issue runtime warning
        deprecated_params: list[str] = []
        if headers is not _UNSET:
            deprecated_params.append("headers")
        if timeout is not _UNSET:
            deprecated_params.append("timeout")
        if sse_read_timeout is not _UNSET:
            deprecated_params.append("sse_read_timeout")
        if auth is not _UNSET:
            deprecated_params.append("auth")

        if deprecated_params:
            warn(
                f"Parameters {', '.join(deprecated_params)} are deprecated and will be ignored. "
                "Configure these on the httpx.AsyncClient instead.",
                DeprecationWarning,
                stacklevel=2,
            )

        self.url = url
        # Session ID and protocol version are learned during initialization
        # (see _maybe_extract_session_id_from_response / ..._protocol_version...).
        self.session_id = None
        self.protocol_version = None
|
||||
|
||||
def _prepare_headers(self) -> dict[str, str]:
|
||||
"""Build MCP-specific request headers.
|
||||
|
||||
These headers will be merged with the httpx.AsyncClient's default headers,
|
||||
with these MCP-specific headers taking precedence.
|
||||
"""
|
||||
headers: dict[str, str] = {}
|
||||
# Add MCP protocol headers
|
||||
headers[ACCEPT] = f"{JSON}, {SSE}"
|
||||
headers[CONTENT_TYPE] = JSON
|
||||
# Add session headers if available
|
||||
if self.session_id:
|
||||
headers[MCP_SESSION_ID] = self.session_id
|
||||
if self.protocol_version:
|
||||
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
|
||||
return headers
|
||||
|
||||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialization request."""
|
||||
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
|
||||
|
||||
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialized notification."""
|
||||
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
|
||||
|
||||
def _maybe_extract_session_id_from_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
) -> None:
|
||||
"""Extract and store session ID from response headers."""
|
||||
new_session_id = response.headers.get(MCP_SESSION_ID)
|
||||
if new_session_id:
|
||||
self.session_id = new_session_id
|
||||
logger.info(f"Received session ID: {self.session_id}")
|
||||
|
||||
    def _maybe_extract_protocol_version_from_message(
        self,
        message: JSONRPCMessage,
    ) -> None:
        """Extract protocol version from initialization response message.

        Stores the negotiated version on self.protocol_version; parse
        failures are logged but never raised, so a malformed result cannot
        break the transport.
        """
        if isinstance(message.root, JSONRPCResponse) and message.root.result:  # pragma: no branch
            try:
                # Parse the result as InitializeResult for type safety
                init_result = InitializeResult.model_validate(message.root.result)
                self.protocol_version = str(init_result.protocolVersion)
                logger.info(f"Negotiated protocol version: {self.protocol_version}")
            except Exception as exc:  # pragma: no cover
                logger.warning(
                    f"Failed to parse initialization response as InitializeResult: {exc}"
                )  # pragma: no cover
                logger.warning(f"Raw result: {message.root.result}")
|
||||
|
||||
    async def _handle_sse_event(
        self,
        sse: ServerSentEvent,
        read_stream_writer: StreamWriter,
        original_request_id: RequestId | None = None,
        resumption_callback: Callable[[str], Awaitable[None]] | None = None,
        is_initialization: bool = False,
    ) -> bool:
        """Handle an SSE event, returning True if the response is complete.

        Args:
            sse: The raw server-sent event.
            read_stream_writer: Where parsed messages (or parse errors) go.
            original_request_id: When resuming, the ID to restore on responses.
            resumption_callback: Invoked with sse.id to record resumption tokens.
            is_initialization: Whether this stream answers an initialize request.
        """
        if sse.event == "message":
            # Handle priming events (empty data with ID) for resumability
            if not sse.data:
                # Call resumption callback for priming events that have an ID
                if sse.id and resumption_callback:
                    await resumption_callback(sse.id)
                return False
            try:
                message = JSONRPCMessage.model_validate_json(sse.data)
                logger.debug(f"SSE message: {message}")

                # Extract protocol version from initialization response
                if is_initialization:
                    self._maybe_extract_protocol_version_from_message(message)

                # If this is a response and we have original_request_id, replace it
                if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
                    message.root.id = original_request_id

                session_message = SessionMessage(message)
                await read_stream_writer.send(session_message)

                # Call resumption token callback if we have an ID
                if sse.id and resumption_callback:
                    await resumption_callback(sse.id)

                # If this is a response or error return True indicating completion
                # Otherwise, return False to continue listening
                return isinstance(message.root, JSONRPCResponse | JSONRPCError)

            except Exception as exc:  # pragma: no cover
                # Forward parse errors to the consumer instead of raising.
                logger.exception("Error parsing SSE message")
                await read_stream_writer.send(exc)
                return False
        else:  # pragma: no cover
            logger.warning(f"Unknown SSE event: {sse.event}")
            return False
|
||||
|
||||
    async def handle_get_stream(
        self,
        client: httpx.AsyncClient,
        read_stream_writer: StreamWriter,
    ) -> None:
        """Handle GET stream for server-initiated messages with auto-reconnect.

        Maintains a long-lived GET SSE connection; on failure, retries up to
        MAX_RECONNECTION_ATTEMPTS with the server-suggested (or default)
        delay, resuming from the last seen event ID. Returns silently when
        there is no session yet or retries are exhausted.
        """
        last_event_id: str | None = None
        retry_interval_ms: int | None = None
        attempt: int = 0

        while attempt < MAX_RECONNECTION_ATTEMPTS:  # pragma: no branch
            try:
                # No session means nothing to listen for yet.
                if not self.session_id:
                    return

                headers = self._prepare_headers()
                if last_event_id:
                    headers[LAST_EVENT_ID] = last_event_id  # pragma: no cover

                async with aconnect_sse(
                    client,
                    "GET",
                    self.url,
                    headers=headers,
                ) as event_source:
                    event_source.response.raise_for_status()
                    logger.debug("GET SSE connection established")

                    async for sse in event_source.aiter_sse():
                        # Track last event ID for reconnection
                        if sse.id:
                            last_event_id = sse.id  # pragma: no cover
                        # Track retry interval from server
                        if sse.retry is not None:
                            retry_interval_ms = sse.retry  # pragma: no cover

                        await self._handle_sse_event(sse, read_stream_writer)

                # Stream ended normally (server closed) - reset attempt counter
                attempt = 0

            except Exception as exc:  # pragma: no cover
                logger.debug(f"GET stream error: {exc}")
                attempt += 1

                if attempt >= MAX_RECONNECTION_ATTEMPTS:  # pragma: no cover
                    logger.debug(f"GET stream max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
                    return

                # Wait before reconnecting
                delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
                logger.info(f"GET stream disconnected, reconnecting in {delay_ms}ms...")
                await anyio.sleep(delay_ms / 1000.0)
|
||||
|
||||
    async def _handle_resumption_request(self, ctx: RequestContext) -> None:
        """Handle a resumption request using GET with SSE.

        Reopens the event stream from ctx.metadata.resumption_token (sent as
        Last-Event-ID) and replays events until the original request's
        response arrives.

        Raises:
            ResumptionError: If no resumption token is available.
        """
        headers = self._prepare_headers()
        if ctx.metadata and ctx.metadata.resumption_token:
            headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
        else:
            raise ResumptionError("Resumption request requires a resumption token")  # pragma: no cover

        # Extract original request ID to map responses
        original_request_id = None
        if isinstance(ctx.session_message.message.root, JSONRPCRequest):  # pragma: no branch
            original_request_id = ctx.session_message.message.root.id

        async with aconnect_sse(
            ctx.client,
            "GET",
            self.url,
            headers=headers,
        ) as event_source:
            event_source.response.raise_for_status()
            logger.debug("Resumption GET SSE connection established")

            async for sse in event_source.aiter_sse():  # pragma: no branch
                is_complete = await self._handle_sse_event(
                    sse,
                    ctx.read_stream_writer,
                    original_request_id,
                    ctx.metadata.on_resumption_token_update if ctx.metadata else None,
                )
                if is_complete:
                    await event_source.response.aclose()
                    break
|
||||
|
||||
    async def _handle_post_request(self, ctx: RequestContext) -> None:
        """Handle a POST request with response processing.

        Sends the message and dispatches on the response: 202 means accepted
        with no body; 404 on a request means the session was terminated
        server-side; otherwise the body is parsed as JSON or streamed as SSE
        depending on Content-Type.
        """
        headers = self._prepare_headers()
        message = ctx.session_message.message
        is_initialization = self._is_initialization_request(message)

        async with ctx.client.stream(
            "POST",
            self.url,
            json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
            headers=headers,
        ) as response:
            if response.status_code == 202:
                logger.debug("Received 202 Accepted")
                return

            if response.status_code == 404:  # pragma: no branch
                # Only requests expect a reply; notifications get no error.
                if isinstance(message.root, JSONRPCRequest):
                    await self._send_session_terminated_error(  # pragma: no cover
                        ctx.read_stream_writer,  # pragma: no cover
                        message.root.id,  # pragma: no cover
                    )  # pragma: no cover
                    return  # pragma: no cover

            response.raise_for_status()
            if is_initialization:
                self._maybe_extract_session_id_from_response(response)

            # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
            # The server MUST NOT send a response to notifications.
            if isinstance(message.root, JSONRPCRequest):
                content_type = response.headers.get(CONTENT_TYPE, "").lower()
                if content_type.startswith(JSON):
                    await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
                elif content_type.startswith(SSE):
                    await self._handle_sse_response(response, ctx, is_initialization)
                else:
                    await self._handle_unexpected_content_type(  # pragma: no cover
                        content_type,  # pragma: no cover
                        ctx.read_stream_writer,  # pragma: no cover
                    )  # pragma: no cover
|
||||
|
||||
    async def _handle_json_response(
        self,
        response: httpx.Response,
        read_stream_writer: StreamWriter,
        is_initialization: bool = False,
    ) -> None:
        """Handle JSON response from the server.

        Parses the body as a single JSON-RPC message and forwards it (or the
        parse error) to the read stream.
        """
        try:
            content = await response.aread()
            message = JSONRPCMessage.model_validate_json(content)

            # Extract protocol version from initialization response
            if is_initialization:
                self._maybe_extract_protocol_version_from_message(message)

            session_message = SessionMessage(message)
            await read_stream_writer.send(session_message)
        except Exception as exc:  # pragma: no cover
            logger.exception("Error parsing JSON response")
            await read_stream_writer.send(exc)
|
||||
|
||||
    async def _handle_sse_response(
        self,
        response: httpx.Response,
        ctx: RequestContext,
        is_initialization: bool = False,
    ) -> None:
        """Handle SSE response from the server.

        Streams events until the request's response/error arrives. If the
        stream drops beforehand and at least one event carried an ID,
        attempts to resume via _handle_reconnection.
        """
        last_event_id: str | None = None
        retry_interval_ms: int | None = None

        try:
            event_source = EventSource(response)
            async for sse in event_source.aiter_sse():  # pragma: no branch
                # Track last event ID for potential reconnection
                if sse.id:
                    last_event_id = sse.id

                # Track retry interval from server
                if sse.retry is not None:
                    retry_interval_ms = sse.retry

                is_complete = await self._handle_sse_event(
                    sse,
                    ctx.read_stream_writer,
                    resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
                    is_initialization=is_initialization,
                )
                # If the SSE event indicates completion, like returning respose/error
                # break the loop
                if is_complete:
                    await response.aclose()
                    return  # Normal completion, no reconnect needed
        except Exception as e:  # pragma: no cover
            logger.debug(f"SSE stream ended: {e}")

        # Stream ended without response - reconnect if we received an event with ID
        if last_event_id is not None:  # pragma: no branch
            logger.info("SSE stream disconnected, reconnecting...")
            await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)
|
||||
|
||||
    async def _handle_reconnection(
        self,
        ctx: RequestContext,
        last_event_id: str,
        retry_interval_ms: int | None = None,
        attempt: int = 0,
    ) -> None:
        """Reconnect with Last-Event-ID to resume stream after server disconnect.

        Recursive: another silent disconnect re-enters with attempt reset to
        0 (progress was made), while a failed connection attempt re-enters
        with attempt + 1 until MAX_RECONNECTION_ATTEMPTS is hit.

        Args:
            ctx: The request context being resumed.
            last_event_id: ID of the last event received, sent as Last-Event-ID.
            retry_interval_ms: Server-suggested delay; falls back to the default.
            attempt: Consecutive failed connection attempts so far.
        """
        # Bail if max retries exceeded
        if attempt >= MAX_RECONNECTION_ATTEMPTS:  # pragma: no cover
            logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
            return

        # Always wait - use server value or default
        delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
        await anyio.sleep(delay_ms / 1000.0)

        headers = self._prepare_headers()
        headers[LAST_EVENT_ID] = last_event_id

        # Extract original request ID to map responses
        original_request_id = None
        if isinstance(ctx.session_message.message.root, JSONRPCRequest):  # pragma: no branch
            original_request_id = ctx.session_message.message.root.id

        try:
            async with aconnect_sse(
                ctx.client,
                "GET",
                self.url,
                headers=headers,
            ) as event_source:
                event_source.response.raise_for_status()
                logger.info("Reconnected to SSE stream")

                # Track for potential further reconnection
                reconnect_last_event_id: str = last_event_id
                reconnect_retry_ms = retry_interval_ms

                async for sse in event_source.aiter_sse():
                    if sse.id:  # pragma: no branch
                        reconnect_last_event_id = sse.id
                    if sse.retry is not None:
                        reconnect_retry_ms = sse.retry

                    is_complete = await self._handle_sse_event(
                        sse,
                        ctx.read_stream_writer,
                        original_request_id,
                        ctx.metadata.on_resumption_token_update if ctx.metadata else None,
                    )
                    if is_complete:
                        await event_source.response.aclose()
                        return

                # Stream ended again without response - reconnect again (reset attempt counter)
                logger.info("SSE stream disconnected, reconnecting...")
                await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
        except Exception as e:  # pragma: no cover
            logger.debug(f"Reconnection failed: {e}")
            # Try to reconnect again if we still have an event ID
            await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)
|
||||
|
||||
async def _handle_unexpected_content_type(
    self,
    content_type: str,
    read_stream_writer: StreamWriter,
) -> None:  # pragma: no cover
    """Report an unsupported response Content-Type to the reader as an error."""
    message = f"Unexpected content type: {content_type}"  # pragma: no cover
    logger.error(message)  # pragma: no cover
    # Push the failure into the read stream so the session sees it.
    await read_stream_writer.send(ValueError(message))  # pragma: no cover
|
||||
|
||||
async def _send_session_terminated_error(
    self,
    read_stream_writer: StreamWriter,
    request_id: RequestId,
) -> None:
    """Send a session terminated error response.

    Delivers a JSON-RPC error for ``request_id`` into the client's read
    stream so the pending request completes instead of hanging forever.
    """
    jsonrpc_error = JSONRPCError(
        jsonrpc="2.0",
        id=request_id,
        # -32600 is the JSON-RPC 2.0 "Invalid Request" code. The previous
        # positive 32600 matched no spec-defined code (reserved codes are
        # negative, -32768..-32000) and looks like a dropped minus sign.
        error=ErrorData(code=-32600, message="Session terminated"),
    )
    session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
    await read_stream_writer.send(session_message)
|
||||
|
||||
async def post_writer(
    self,
    client: httpx.AsyncClient,
    write_stream_reader: StreamReader,
    read_stream_writer: StreamWriter,
    write_stream: MemoryObjectSendStream[SessionMessage],
    start_get_stream: Callable[[], None],
    tg: TaskGroup,
) -> None:
    """Handle writing requests to the server.

    Consumes messages from ``write_stream_reader`` and dispatches each one:
    requests are handled concurrently in the task group, notifications and
    responses are handled inline. Closes both streams on exit.
    """
    try:
        async with write_stream_reader:
            async for session_message in write_stream_reader:
                message = session_message.message
                metadata = (
                    session_message.metadata
                    if isinstance(session_message.metadata, ClientMessageMetadata)
                    else None
                )

                # Check if this is a resumption request
                is_resumption = bool(metadata and metadata.resumption_token)

                logger.debug(f"Sending client message: {message}")

                # The initialized notification signals that the server-side
                # GET (SSE) stream may now be opened.
                if self._is_initialized_notification(message):
                    start_get_stream()

                ctx = RequestContext(
                    client=client,
                    session_id=self.session_id,
                    session_message=session_message,
                    metadata=metadata,
                    read_stream_writer=read_stream_writer,
                )

                # Bind the current ctx/is_resumption as default arguments:
                # `tg.start_soon` may run the task only after this loop has
                # advanced and rebound those names, and a plain closure would
                # then act on the wrong message (late-binding closure bug).
                async def handle_request_async(ctx=ctx, is_resumption=is_resumption):
                    if is_resumption:
                        await self._handle_resumption_request(ctx)
                    else:
                        await self._handle_post_request(ctx)

                # If this is a request, start a new task to handle it so
                # multiple requests can be in flight concurrently.
                if isinstance(message.root, JSONRPCRequest):
                    tg.start_soon(handle_request_async)
                else:
                    await handle_request_async()

    except Exception:
        logger.exception("Error in post_writer")  # pragma: no cover
    finally:
        await read_stream_writer.aclose()
        await write_stream.aclose()
|
||||
|
||||
async def terminate_session(self, client: httpx.AsyncClient) -> None:  # pragma: no cover
    """Terminate the session by sending a DELETE request."""
    # Nothing to terminate if the server never assigned a session.
    if not self.session_id:
        return

    try:
        response = await client.delete(self.url, headers=self._prepare_headers())

        if response.status_code == 405:
            # Explicit termination is optional; 405 means the server opted out.
            logger.debug("Server does not allow session termination")
        elif response.status_code not in (200, 204):
            logger.warning(f"Session termination failed: {response.status_code}")
    except Exception as exc:
        # Best-effort cleanup: never let termination errors propagate.
        logger.warning(f"Session termination failed: {exc}")
|
||||
|
||||
def get_session_id(self) -> str | None:
|
||||
"""Get the current session ID."""
|
||||
return self.session_id
|
||||
|
||||
|
||||
@asynccontextmanager
async def streamable_http_client(
    url: str,
    *,
    http_client: httpx.AsyncClient | None = None,
    terminate_on_close: bool = True,
) -> AsyncGenerator[
    tuple[
        MemoryObjectReceiveStream[SessionMessage | Exception],
        MemoryObjectSendStream[SessionMessage],
        GetSessionIdCallback,
    ],
    None,
]:
    """
    Client transport for StreamableHTTP.

    Args:
        url: The MCP server endpoint URL.
        http_client: Optional pre-configured ``httpx.AsyncClient``. When omitted,
            a default client with recommended MCP timeouts is created. To customize
            headers, authentication, or other HTTP settings, build your own
            ``httpx.AsyncClient`` and pass it here.
        terminate_on_close: When True, a DELETE request is sent on context exit
            to terminate the server-side session.

    Yields:
        A ``(read_stream, write_stream, get_session_id_callback)`` triple:
        streams for receiving and sending messages, plus a callback that
        returns the current session ID.

    Example:
        See examples/snippets/clients/ for usage patterns.
    """
    read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

    # Only manage the lifecycle of a client we created ourselves; a
    # caller-supplied client is the caller's responsibility.
    caller_owns_client = http_client is not None
    client = http_client if http_client is not None else create_mcp_http_client()

    transport = StreamableHTTPTransport(url)

    async with anyio.create_task_group() as tg:
        try:
            logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

            async with contextlib.AsyncExitStack() as stack:
                if not caller_owns_client:
                    await stack.enter_async_context(client)

                def start_get_stream() -> None:
                    # Open the server->client SSE stream on demand.
                    tg.start_soon(transport.handle_get_stream, client, read_stream_writer)

                tg.start_soon(
                    transport.post_writer,
                    client,
                    write_stream_reader,
                    read_stream_writer,
                    write_stream,
                    start_get_stream,
                    tg,
                )

                try:
                    yield read_stream, write_stream, transport.get_session_id
                finally:
                    # Best-effort session cleanup, then tear the task group down.
                    if transport.session_id and terminate_on_close:
                        await transport.terminate_session(client)
                    tg.cancel_scope.cancel()
        finally:
            await read_stream_writer.aclose()
            await write_stream.aclose()
|
||||
|
||||
|
||||
@asynccontextmanager
@deprecated("Use `streamable_http_client` instead.")
async def streamablehttp_client(
    url: str,
    headers: dict[str, str] | None = None,
    timeout: float | timedelta = 30,
    sse_read_timeout: float | timedelta = 60 * 5,
    terminate_on_close: bool = True,
    httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
    auth: httpx.Auth | None = None,
) -> AsyncGenerator[
    tuple[
        MemoryObjectReceiveStream[SessionMessage | Exception],
        MemoryObjectSendStream[SessionMessage],
        GetSessionIdCallback,
    ],
    None,
]:
    """Deprecated wrapper that builds an httpx client from legacy parameters
    and delegates to :func:`streamable_http_client`."""

    def _as_seconds(value: float | timedelta) -> float:
        # Legacy API accepted either raw seconds or a timedelta.
        return value.total_seconds() if isinstance(value, timedelta) else value

    # Create the httpx client using the factory with old-style parameters.
    client = httpx_client_factory(
        headers=headers,
        timeout=httpx.Timeout(_as_seconds(timeout), read=_as_seconds(sse_read_timeout)),
        auth=auth,
    )

    # We created this client, so we own its lifecycle.
    async with client:
        async with streamable_http_client(
            url,
            http_client=client,
            terminate_on_close=terminate_on_close,
        ) as streams:
            yield streams
|
||||
86
.venv/lib/python3.11/site-packages/mcp/client/websocket.py
Normal file
86
.venv/lib/python3.11/site-packages/mcp/client/websocket.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import ValidationError
|
||||
from websockets.asyncio.client import connect as ws_connect
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
async def websocket_client(
    url: str,
) -> AsyncGenerator[
    tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]],
    None,
]:
    """
    WebSocket client transport for MCP, symmetrical to the server version.

    Connects to ``url`` using the ``mcp`` subprotocol and yields a
    ``(read_stream, write_stream)`` pair:

    - read_stream: yields valid ``JSONRPCMessage`` objects wrapped in
      ``SessionMessage``, or Exception objects when validation fails.
    - write_stream: accepts ``SessionMessage`` objects to send over the
      WebSocket to the server.
    """
    # Two in-memory pipes: one carrying inbound messages to the caller,
    # one carrying outbound messages from the caller.
    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    # Connect with websockets, requesting the "mcp" subprotocol.
    async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:

        async def _pump_incoming() -> None:
            """Parse inbound WebSocket text frames into SessionMessages."""
            async with read_stream_writer:
                async for raw_text in ws:
                    try:
                        parsed = types.JSONRPCMessage.model_validate_json(raw_text)
                        await read_stream_writer.send(SessionMessage(parsed))
                    except ValidationError as exc:  # pragma: no cover
                        # Forward parse/validation failures to the reader.
                        await read_stream_writer.send(exc)

        async def _pump_outgoing() -> None:
            """Serialize outbound SessionMessages and send them to the server."""
            async with write_stream_reader:
                async for outbound in write_stream_reader:
                    payload = outbound.message.model_dump(by_alias=True, mode="json", exclude_none=True)
                    await ws.send(json.dumps(payload))

        async with anyio.create_task_group() as tg:
            tg.start_soon(_pump_incoming)
            tg.start_soon(_pump_outgoing)

            # Hand the streams to the caller for the lifetime of the context.
            yield (read_stream, write_stream)

            # Caller exited: shut both pump tasks down.
            tg.cancel_scope.cancel()
|
||||
Reference in New Issue
Block a user