Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for active project before fetching all sessions - When active project exists, use project.sessions instead of fetching from API - Added detailed console logging to debug session filtering - This prevents ALL sessions from appearing in every project's sidebar Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,87 @@
|
||||
"""Utilities for creating standardized httpx AsyncClient instances."""
|
||||
|
||||
from typing import Any, Protocol
|
||||
|
||||
import httpx
|
||||
|
||||
__all__ = ["create_mcp_http_client", "MCP_DEFAULT_TIMEOUT", "MCP_DEFAULT_SSE_READ_TIMEOUT"]
|
||||
|
||||
# Default MCP timeout configuration
|
||||
MCP_DEFAULT_TIMEOUT = 30.0 # General operations (seconds)
|
||||
MCP_DEFAULT_SSE_READ_TIMEOUT = 300.0 # SSE streams - 5 minutes (seconds)
|
||||
|
||||
|
||||
class McpHttpClientFactory(Protocol): # pragma: no branch
|
||||
def __call__( # pragma: no branch
|
||||
self,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
auth: httpx.Auth | None = None,
|
||||
) -> httpx.AsyncClient: ...
|
||||
|
||||
|
||||
def create_mcp_http_client(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
auth: httpx.Auth | None = None,
|
||||
) -> httpx.AsyncClient:
|
||||
"""Create a standardized httpx AsyncClient with MCP defaults.
|
||||
|
||||
This function provides common defaults used throughout the MCP codebase:
|
||||
- follow_redirects=True (always enabled)
|
||||
- Default timeout of 30 seconds if not specified
|
||||
|
||||
Args:
|
||||
headers: Optional headers to include with all requests.
|
||||
timeout: Request timeout as httpx.Timeout object.
|
||||
Defaults to 30 seconds if not specified.
|
||||
auth: Optional authentication handler.
|
||||
|
||||
Returns:
|
||||
Configured httpx.AsyncClient instance with MCP defaults.
|
||||
|
||||
Note:
|
||||
The returned AsyncClient must be used as a context manager to ensure
|
||||
proper cleanup of connections.
|
||||
|
||||
Examples:
|
||||
# Basic usage with MCP defaults
|
||||
async with create_mcp_http_client() as client:
|
||||
response = await client.get("https://api.example.com")
|
||||
|
||||
# With custom headers
|
||||
headers = {"Authorization": "Bearer token"}
|
||||
async with create_mcp_http_client(headers) as client:
|
||||
response = await client.get("/endpoint")
|
||||
|
||||
# With both custom headers and timeout
|
||||
timeout = httpx.Timeout(60.0, read=300.0)
|
||||
async with create_mcp_http_client(headers, timeout) as client:
|
||||
response = await client.get("/long-request")
|
||||
|
||||
# With authentication
|
||||
from httpx import BasicAuth
|
||||
auth = BasicAuth(username="user", password="pass")
|
||||
async with create_mcp_http_client(headers, timeout, auth) as client:
|
||||
response = await client.get("/protected-endpoint")
|
||||
"""
|
||||
# Set MCP defaults
|
||||
kwargs: dict[str, Any] = {
|
||||
"follow_redirects": True,
|
||||
}
|
||||
|
||||
# Handle timeout
|
||||
if timeout is None:
|
||||
kwargs["timeout"] = httpx.Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT)
|
||||
else:
|
||||
kwargs["timeout"] = timeout
|
||||
|
||||
# Handle headers
|
||||
if headers is not None:
|
||||
kwargs["headers"] = headers
|
||||
|
||||
# Handle authentication
|
||||
if auth is not None: # pragma: no cover
|
||||
kwargs["auth"] = auth
|
||||
|
||||
return httpx.AsyncClient(**kwargs)
|
||||
159
.venv/lib/python3.11/site-packages/mcp/shared/auth.py
Normal file
159
.venv/lib/python3.11/site-packages/mcp/shared/auth.py
Normal file
@@ -0,0 +1,159 @@
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class OAuthToken(BaseModel):
|
||||
"""
|
||||
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
|
||||
"""
|
||||
|
||||
access_token: str
|
||||
token_type: Literal["Bearer"] = "Bearer"
|
||||
expires_in: int | None = None
|
||||
scope: str | None = None
|
||||
refresh_token: str | None = None
|
||||
|
||||
@field_validator("token_type", mode="before")
|
||||
@classmethod
|
||||
def normalize_token_type(cls, v: str | None) -> str | None:
|
||||
if isinstance(v, str):
|
||||
# Bearer is title-cased in the spec, so we normalize it
|
||||
# https://datatracker.ietf.org/doc/html/rfc6750#section-4
|
||||
return v.title()
|
||||
return v # pragma: no cover
|
||||
|
||||
|
||||
class InvalidScopeError(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
|
||||
|
||||
class InvalidRedirectUriError(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
|
||||
|
||||
class OAuthClientMetadata(BaseModel):
|
||||
"""
|
||||
RFC 7591 OAuth 2.0 Dynamic Client Registration metadata.
|
||||
See https://datatracker.ietf.org/doc/html/rfc7591#section-2
|
||||
for the full specification.
|
||||
"""
|
||||
|
||||
redirect_uris: list[AnyUrl] | None = Field(..., min_length=1)
|
||||
# supported auth methods for the token endpoint
|
||||
token_endpoint_auth_method: (
|
||||
Literal["none", "client_secret_post", "client_secret_basic", "private_key_jwt"] | None
|
||||
) = None
|
||||
# supported grant_types of this implementation
|
||||
grant_types: list[
|
||||
Literal["authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:jwt-bearer"] | str
|
||||
] = [
|
||||
"authorization_code",
|
||||
"refresh_token",
|
||||
]
|
||||
# The MCP spec requires the "code" response type, but OAuth
|
||||
# servers may also return additional types they support
|
||||
response_types: list[str] = ["code"]
|
||||
scope: str | None = None
|
||||
|
||||
# these fields are currently unused, but we support & store them for potential
|
||||
# future use
|
||||
client_name: str | None = None
|
||||
client_uri: AnyHttpUrl | None = None
|
||||
logo_uri: AnyHttpUrl | None = None
|
||||
contacts: list[str] | None = None
|
||||
tos_uri: AnyHttpUrl | None = None
|
||||
policy_uri: AnyHttpUrl | None = None
|
||||
jwks_uri: AnyHttpUrl | None = None
|
||||
jwks: Any | None = None
|
||||
software_id: str | None = None
|
||||
software_version: str | None = None
|
||||
|
||||
def validate_scope(self, requested_scope: str | None) -> list[str] | None:
|
||||
if requested_scope is None:
|
||||
return None
|
||||
requested_scopes = requested_scope.split(" ")
|
||||
allowed_scopes = [] if self.scope is None else self.scope.split(" ")
|
||||
for scope in requested_scopes:
|
||||
if scope not in allowed_scopes: # pragma: no branch
|
||||
raise InvalidScopeError(f"Client was not registered with scope {scope}")
|
||||
return requested_scopes # pragma: no cover
|
||||
|
||||
def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
|
||||
if redirect_uri is not None:
|
||||
# Validate redirect_uri against client's registered redirect URIs
|
||||
if self.redirect_uris is None or redirect_uri not in self.redirect_uris:
|
||||
raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client")
|
||||
return redirect_uri
|
||||
elif self.redirect_uris is not None and len(self.redirect_uris) == 1:
|
||||
return self.redirect_uris[0]
|
||||
else:
|
||||
raise InvalidRedirectUriError("redirect_uri must be specified when client has multiple registered URIs")
|
||||
|
||||
|
||||
class OAuthClientInformationFull(OAuthClientMetadata):
|
||||
"""
|
||||
RFC 7591 OAuth 2.0 Dynamic Client Registration full response
|
||||
(client information plus metadata).
|
||||
"""
|
||||
|
||||
client_id: str | None = None
|
||||
client_secret: str | None = None
|
||||
client_id_issued_at: int | None = None
|
||||
client_secret_expires_at: int | None = None
|
||||
|
||||
|
||||
class OAuthMetadata(BaseModel):
|
||||
"""
|
||||
RFC 8414 OAuth 2.0 Authorization Server Metadata.
|
||||
See https://datatracker.ietf.org/doc/html/rfc8414#section-2
|
||||
"""
|
||||
|
||||
issuer: AnyHttpUrl
|
||||
authorization_endpoint: AnyHttpUrl
|
||||
token_endpoint: AnyHttpUrl
|
||||
registration_endpoint: AnyHttpUrl | None = None
|
||||
scopes_supported: list[str] | None = None
|
||||
response_types_supported: list[str] = ["code"]
|
||||
response_modes_supported: list[str] | None = None
|
||||
grant_types_supported: list[str] | None = None
|
||||
token_endpoint_auth_methods_supported: list[str] | None = None
|
||||
token_endpoint_auth_signing_alg_values_supported: list[str] | None = None
|
||||
service_documentation: AnyHttpUrl | None = None
|
||||
ui_locales_supported: list[str] | None = None
|
||||
op_policy_uri: AnyHttpUrl | None = None
|
||||
op_tos_uri: AnyHttpUrl | None = None
|
||||
revocation_endpoint: AnyHttpUrl | None = None
|
||||
revocation_endpoint_auth_methods_supported: list[str] | None = None
|
||||
revocation_endpoint_auth_signing_alg_values_supported: list[str] | None = None
|
||||
introspection_endpoint: AnyHttpUrl | None = None
|
||||
introspection_endpoint_auth_methods_supported: list[str] | None = None
|
||||
introspection_endpoint_auth_signing_alg_values_supported: list[str] | None = None
|
||||
code_challenge_methods_supported: list[str] | None = None
|
||||
client_id_metadata_document_supported: bool | None = None
|
||||
|
||||
|
||||
class ProtectedResourceMetadata(BaseModel):
|
||||
"""
|
||||
RFC 9728 OAuth 2.0 Protected Resource Metadata.
|
||||
See https://datatracker.ietf.org/doc/html/rfc9728#section-2
|
||||
"""
|
||||
|
||||
resource: AnyHttpUrl
|
||||
authorization_servers: list[AnyHttpUrl] = Field(..., min_length=1)
|
||||
jwks_uri: AnyHttpUrl | None = None
|
||||
scopes_supported: list[str] | None = None
|
||||
bearer_methods_supported: list[str] | None = Field(default=["header"]) # MCP only supports header method
|
||||
resource_signing_alg_values_supported: list[str] | None = None
|
||||
resource_name: str | None = None
|
||||
resource_documentation: AnyHttpUrl | None = None
|
||||
resource_policy_uri: AnyHttpUrl | None = None
|
||||
resource_tos_uri: AnyHttpUrl | None = None
|
||||
# tls_client_certificate_bound_access_tokens default is False, but ommited here for clarity
|
||||
tls_client_certificate_bound_access_tokens: bool | None = None
|
||||
authorization_details_types_supported: list[str] | None = None
|
||||
dpop_signing_alg_values_supported: list[str] | None = None
|
||||
# dpop_bound_access_tokens_required default is False, but ommited here for clarity
|
||||
dpop_bound_access_tokens_required: bool | None = None
|
||||
85
.venv/lib/python3.11/site-packages/mcp/shared/auth_utils.py
Normal file
85
.venv/lib/python3.11/site-packages/mcp/shared/auth_utils.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707) and PKCE (RFC 7636)."""
|
||||
|
||||
import time
|
||||
from urllib.parse import urlparse, urlsplit, urlunsplit
|
||||
|
||||
from pydantic import AnyUrl, HttpUrl
|
||||
|
||||
|
||||
def resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str:
|
||||
"""Convert server URL to canonical resource URL per RFC 8707.
|
||||
|
||||
RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component".
|
||||
Returns absolute URI with lowercase scheme/host for canonical form.
|
||||
|
||||
Args:
|
||||
url: Server URL to convert
|
||||
|
||||
Returns:
|
||||
Canonical resource URL string
|
||||
"""
|
||||
# Convert to string if needed
|
||||
url_str = str(url)
|
||||
|
||||
# Parse the URL and remove fragment, create canonical form
|
||||
parsed = urlsplit(url_str)
|
||||
canonical = urlunsplit(parsed._replace(scheme=parsed.scheme.lower(), netloc=parsed.netloc.lower(), fragment=""))
|
||||
|
||||
return canonical
|
||||
|
||||
|
||||
def check_resource_allowed(requested_resource: str, configured_resource: str) -> bool:
|
||||
"""Check if a requested resource URL matches a configured resource URL.
|
||||
|
||||
A requested resource matches if it has the same scheme, domain, port,
|
||||
and its path starts with the configured resource's path. This allows
|
||||
hierarchical matching where a token for a parent resource can be used
|
||||
for child resources.
|
||||
|
||||
Args:
|
||||
requested_resource: The resource URL being requested
|
||||
configured_resource: The resource URL that has been configured
|
||||
|
||||
Returns:
|
||||
True if the requested resource matches the configured resource
|
||||
"""
|
||||
# Parse both URLs
|
||||
requested = urlparse(requested_resource)
|
||||
configured = urlparse(configured_resource)
|
||||
|
||||
# Compare scheme, host, and port (origin)
|
||||
if requested.scheme.lower() != configured.scheme.lower() or requested.netloc.lower() != configured.netloc.lower():
|
||||
return False
|
||||
|
||||
# Handle cases like requested=/foo and configured=/foo/
|
||||
requested_path = requested.path
|
||||
configured_path = configured.path
|
||||
|
||||
# If requested path is shorter, it cannot be a child
|
||||
if len(requested_path) < len(configured_path):
|
||||
return False
|
||||
|
||||
# Check if the requested path starts with the configured path
|
||||
# Ensure both paths end with / for proper comparison
|
||||
# This ensures that paths like "/api123" don't incorrectly match "/api"
|
||||
if not requested_path.endswith("/"):
|
||||
requested_path += "/"
|
||||
if not configured_path.endswith("/"):
|
||||
configured_path += "/"
|
||||
|
||||
return requested_path.startswith(configured_path)
|
||||
|
||||
|
||||
def calculate_token_expiry(expires_in: int | str | None) -> float | None:
|
||||
"""Calculate token expiry timestamp from expires_in seconds.
|
||||
|
||||
Args:
|
||||
expires_in: Seconds until token expiration (may be string from some servers)
|
||||
|
||||
Returns:
|
||||
Unix timestamp when token expires, or None if no expiry specified
|
||||
"""
|
||||
if expires_in is None:
|
||||
return None # pragma: no cover
|
||||
# Defensive: handle servers that return expires_in as string
|
||||
return time.time() + int(expires_in)
|
||||
32
.venv/lib/python3.11/site-packages/mcp/shared/context.py
Normal file
32
.venv/lib/python3.11/site-packages/mcp/shared/context.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Request context for MCP handlers.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Generic
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from mcp.shared.message import CloseSSEStreamCallback
|
||||
from mcp.shared.session import BaseSession
|
||||
from mcp.types import RequestId, RequestParams
|
||||
|
||||
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
|
||||
LifespanContextT = TypeVar("LifespanContextT")
|
||||
RequestT = TypeVar("RequestT", default=Any)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
|
||||
request_id: RequestId
|
||||
meta: RequestParams.Meta | None
|
||||
session: SessionT
|
||||
lifespan_context: LifespanContextT
|
||||
# NOTE: This is typed as Any to avoid circular imports. The actual type is
|
||||
# mcp.server.experimental.request_context.Experimental, but importing it here
|
||||
# triggers mcp.server.__init__ -> fastmcp -> tools -> back to this module.
|
||||
# The Server sets this to an Experimental instance at runtime.
|
||||
experimental: Any = field(default=None)
|
||||
request: RequestT | None = None
|
||||
close_sse_stream: CloseSSEStreamCallback | None = None
|
||||
close_standalone_sse_stream: CloseSSEStreamCallback | None = None
|
||||
71
.venv/lib/python3.11/site-packages/mcp/shared/exceptions.py
Normal file
71
.venv/lib/python3.11/site-packages/mcp/shared/exceptions.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData
|
||||
|
||||
|
||||
class McpError(Exception):
|
||||
"""
|
||||
Exception type raised when an error arrives over an MCP connection.
|
||||
"""
|
||||
|
||||
error: ErrorData
|
||||
|
||||
def __init__(self, error: ErrorData):
|
||||
"""Initialize McpError."""
|
||||
super().__init__(error.message)
|
||||
self.error = error
|
||||
|
||||
|
||||
class UrlElicitationRequiredError(McpError):
|
||||
"""
|
||||
Specialized error for when a tool requires URL mode elicitation(s) before proceeding.
|
||||
|
||||
Servers can raise this error from tool handlers to indicate that the client
|
||||
must complete one or more URL elicitations before the request can be processed.
|
||||
|
||||
Example:
|
||||
raise UrlElicitationRequiredError([
|
||||
ElicitRequestURLParams(
|
||||
mode="url",
|
||||
message="Authorization required for your files",
|
||||
url="https://example.com/oauth/authorize",
|
||||
elicitationId="auth-001"
|
||||
)
|
||||
])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
elicitations: list[ElicitRequestURLParams],
|
||||
message: str | None = None,
|
||||
):
|
||||
"""Initialize UrlElicitationRequiredError."""
|
||||
if message is None:
|
||||
message = f"URL elicitation{'s' if len(elicitations) > 1 else ''} required"
|
||||
|
||||
self._elicitations = elicitations
|
||||
|
||||
error = ErrorData(
|
||||
code=URL_ELICITATION_REQUIRED,
|
||||
message=message,
|
||||
data={"elicitations": [e.model_dump(by_alias=True, exclude_none=True) for e in elicitations]},
|
||||
)
|
||||
super().__init__(error)
|
||||
|
||||
@property
|
||||
def elicitations(self) -> list[ElicitRequestURLParams]:
|
||||
"""The list of URL elicitations required before the request can proceed."""
|
||||
return self._elicitations
|
||||
|
||||
@classmethod
|
||||
def from_error(cls, error: ErrorData) -> UrlElicitationRequiredError:
|
||||
"""Reconstruct from an ErrorData received over the wire."""
|
||||
if error.code != URL_ELICITATION_REQUIRED:
|
||||
raise ValueError(f"Expected error code {URL_ELICITATION_REQUIRED}, got {error.code}")
|
||||
|
||||
data = cast(dict[str, Any], error.data or {})
|
||||
raw_elicitations = cast(list[dict[str, Any]], data.get("elicitations", []))
|
||||
elicitations = [ElicitRequestURLParams.model_validate(e) for e in raw_elicitations]
|
||||
return cls(elicitations, error.message)
|
||||
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Pure experimental MCP features (no server dependencies).
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
For server-integrated experimental features, use mcp.server.experimental.
|
||||
"""
|
||||
Binary file not shown.
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Pure task state management for MCP.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Import directly from submodules:
|
||||
- mcp.shared.experimental.tasks.store.TaskStore
|
||||
- mcp.shared.experimental.tasks.context.TaskContext
|
||||
- mcp.shared.experimental.tasks.in_memory_task_store.InMemoryTaskStore
|
||||
- mcp.shared.experimental.tasks.message_queue.TaskMessageQueue
|
||||
- mcp.shared.experimental.tasks.helpers.is_terminal
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Tasks capability checking utilities.
|
||||
|
||||
This module provides functions for checking and requiring task-related
|
||||
capabilities. All tasks capability logic is centralized here to keep
|
||||
the main session code clean.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.types import (
|
||||
INVALID_REQUEST,
|
||||
ClientCapabilities,
|
||||
ClientTasksCapability,
|
||||
ErrorData,
|
||||
)
|
||||
|
||||
|
||||
def check_tasks_capability(
|
||||
required: ClientTasksCapability,
|
||||
client: ClientTasksCapability,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if client's tasks capability matches the required capability.
|
||||
|
||||
Args:
|
||||
required: The capability being checked for
|
||||
client: The client's declared capabilities
|
||||
|
||||
Returns:
|
||||
True if client has the required capability, False otherwise
|
||||
"""
|
||||
if required.requests is None:
|
||||
return True
|
||||
if client.requests is None:
|
||||
return False
|
||||
|
||||
# Check elicitation.create
|
||||
if required.requests.elicitation is not None:
|
||||
if client.requests.elicitation is None:
|
||||
return False
|
||||
if required.requests.elicitation.create is not None:
|
||||
if client.requests.elicitation.create is None:
|
||||
return False
|
||||
|
||||
# Check sampling.createMessage
|
||||
if required.requests.sampling is not None:
|
||||
if client.requests.sampling is None:
|
||||
return False
|
||||
if required.requests.sampling.createMessage is not None:
|
||||
if client.requests.sampling.createMessage is None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def has_task_augmented_elicitation(caps: ClientCapabilities) -> bool:
|
||||
"""Check if capabilities include task-augmented elicitation support."""
|
||||
if caps.tasks is None:
|
||||
return False
|
||||
if caps.tasks.requests is None:
|
||||
return False
|
||||
if caps.tasks.requests.elicitation is None:
|
||||
return False
|
||||
return caps.tasks.requests.elicitation.create is not None
|
||||
|
||||
|
||||
def has_task_augmented_sampling(caps: ClientCapabilities) -> bool:
|
||||
"""Check if capabilities include task-augmented sampling support."""
|
||||
if caps.tasks is None:
|
||||
return False
|
||||
if caps.tasks.requests is None:
|
||||
return False
|
||||
if caps.tasks.requests.sampling is None:
|
||||
return False
|
||||
return caps.tasks.requests.sampling.createMessage is not None
|
||||
|
||||
|
||||
def require_task_augmented_elicitation(client_caps: ClientCapabilities | None) -> None:
|
||||
"""
|
||||
Raise McpError if client doesn't support task-augmented elicitation.
|
||||
|
||||
Args:
|
||||
client_caps: The client's declared capabilities, or None if not initialized
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support task-augmented elicitation
|
||||
"""
|
||||
if client_caps is None or not has_task_augmented_elicitation(client_caps):
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_REQUEST,
|
||||
message="Client does not support task-augmented elicitation",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def require_task_augmented_sampling(client_caps: ClientCapabilities | None) -> None:
|
||||
"""
|
||||
Raise McpError if client doesn't support task-augmented sampling.
|
||||
|
||||
Args:
|
||||
client_caps: The client's declared capabilities, or None if not initialized
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support task-augmented sampling
|
||||
"""
|
||||
if client_caps is None or not has_task_augmented_sampling(client_caps):
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_REQUEST,
|
||||
message="Client does not support task-augmented sampling",
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
TaskContext - Pure task state management.
|
||||
|
||||
This module provides TaskContext, which manages task state without any
|
||||
server/session dependencies. It can be used standalone for distributed
|
||||
workers or wrapped by ServerTaskContext for full server integration.
|
||||
"""
|
||||
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
from mcp.types import TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, Result, Task
|
||||
|
||||
|
||||
class TaskContext:
|
||||
"""
|
||||
Pure task state management - no session dependencies.
|
||||
|
||||
This class handles:
|
||||
- Task state (status, result)
|
||||
- Cancellation tracking
|
||||
- Store interactions
|
||||
|
||||
For server-integrated features (elicit, create_message, notifications),
|
||||
use ServerTaskContext from mcp.server.experimental.
|
||||
|
||||
Example (distributed worker):
|
||||
async def worker_job(task_id: str):
|
||||
store = RedisTaskStore(redis_url)
|
||||
task = await store.get_task(task_id)
|
||||
ctx = TaskContext(task=task, store=store)
|
||||
|
||||
await ctx.update_status("Working...")
|
||||
result = await do_work()
|
||||
await ctx.complete(result)
|
||||
"""
|
||||
|
||||
def __init__(self, task: Task, store: TaskStore):
|
||||
self._task = task
|
||||
self._store = store
|
||||
self._cancelled = False
|
||||
|
||||
@property
|
||||
def task_id(self) -> str:
|
||||
"""The task identifier."""
|
||||
return self._task.taskId
|
||||
|
||||
@property
|
||||
def task(self) -> Task:
|
||||
"""The current task state."""
|
||||
return self._task
|
||||
|
||||
@property
|
||||
def is_cancelled(self) -> bool:
|
||||
"""Whether cancellation has been requested."""
|
||||
return self._cancelled
|
||||
|
||||
def request_cancellation(self) -> None:
|
||||
"""
|
||||
Request cancellation of this task.
|
||||
|
||||
This sets is_cancelled=True. Task work should check this
|
||||
periodically and exit gracefully if set.
|
||||
"""
|
||||
self._cancelled = True
|
||||
|
||||
async def update_status(self, message: str) -> None:
|
||||
"""
|
||||
Update the task's status message.
|
||||
|
||||
Args:
|
||||
message: The new status message
|
||||
"""
|
||||
self._task = await self._store.update_task(
|
||||
self.task_id,
|
||||
status_message=message,
|
||||
)
|
||||
|
||||
async def complete(self, result: Result) -> None:
|
||||
"""
|
||||
Mark the task as completed with the given result.
|
||||
|
||||
Args:
|
||||
result: The task result
|
||||
"""
|
||||
await self._store.store_result(self.task_id, result)
|
||||
self._task = await self._store.update_task(
|
||||
self.task_id,
|
||||
status=TASK_STATUS_COMPLETED,
|
||||
)
|
||||
|
||||
async def fail(self, error: str) -> None:
|
||||
"""
|
||||
Mark the task as failed with an error message.
|
||||
|
||||
Args:
|
||||
error: The error message
|
||||
"""
|
||||
self._task = await self._store.update_task(
|
||||
self.task_id,
|
||||
status=TASK_STATUS_FAILED,
|
||||
status_message=error,
|
||||
)
|
||||
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
Helper functions for pure task management.
|
||||
|
||||
These helpers work with pure TaskContext and don't require server dependencies.
|
||||
For server-integrated task helpers, use mcp.server.experimental.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.experimental.tasks.context import TaskContext
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
from mcp.types import (
|
||||
INVALID_PARAMS,
|
||||
TASK_STATUS_CANCELLED,
|
||||
TASK_STATUS_COMPLETED,
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_WORKING,
|
||||
CancelTaskResult,
|
||||
ErrorData,
|
||||
Task,
|
||||
TaskMetadata,
|
||||
TaskStatus,
|
||||
)
|
||||
|
||||
# Metadata key for model-immediate-response (per MCP spec)
|
||||
# Servers MAY include this in CreateTaskResult._meta to provide an immediate
|
||||
# response string while the task executes in the background.
|
||||
MODEL_IMMEDIATE_RESPONSE_KEY = "io.modelcontextprotocol/model-immediate-response"
|
||||
|
||||
# Metadata key for associating requests with a task (per MCP spec)
|
||||
RELATED_TASK_METADATA_KEY = "io.modelcontextprotocol/related-task"
|
||||
|
||||
|
||||
def is_terminal(status: TaskStatus) -> bool:
|
||||
"""
|
||||
Check if a task status represents a terminal state.
|
||||
|
||||
Terminal states are those where the task has finished and will not change.
|
||||
|
||||
Args:
|
||||
status: The task status to check
|
||||
|
||||
Returns:
|
||||
True if the status is terminal (completed, failed, or cancelled)
|
||||
"""
|
||||
return status in (TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED)
|
||||
|
||||
|
||||
async def cancel_task(
|
||||
store: TaskStore,
|
||||
task_id: str,
|
||||
) -> CancelTaskResult:
|
||||
"""
|
||||
Cancel a task with spec-compliant validation.
|
||||
|
||||
Per spec: "Receivers MUST reject cancellation of terminal status tasks
|
||||
with -32602 (Invalid params)"
|
||||
|
||||
This helper validates that the task exists and is not in a terminal state
|
||||
before setting it to "cancelled".
|
||||
|
||||
Args:
|
||||
store: The task store
|
||||
task_id: The task identifier to cancel
|
||||
|
||||
Returns:
|
||||
CancelTaskResult with the cancelled task state
|
||||
|
||||
Raises:
|
||||
McpError: With INVALID_PARAMS (-32602) if:
|
||||
- Task does not exist
|
||||
- Task is already in a terminal state (completed, failed, cancelled)
|
||||
|
||||
Example:
|
||||
@server.experimental.cancel_task()
|
||||
async def handle_cancel(request: CancelTaskRequest) -> CancelTaskResult:
|
||||
return await cancel_task(store, request.params.taskId)
|
||||
"""
|
||||
task = await store.get_task(task_id)
|
||||
if task is None:
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_PARAMS,
|
||||
message=f"Task not found: {task_id}",
|
||||
)
|
||||
)
|
||||
|
||||
if is_terminal(task.status):
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_PARAMS,
|
||||
message=f"Cannot cancel task in terminal state '{task.status}'",
|
||||
)
|
||||
)
|
||||
|
||||
# Update task to cancelled status
|
||||
cancelled_task = await store.update_task(task_id, status=TASK_STATUS_CANCELLED)
|
||||
return CancelTaskResult(**cancelled_task.model_dump())
|
||||
|
||||
|
||||
def generate_task_id() -> str:
|
||||
"""Generate a unique task ID."""
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def create_task_state(
|
||||
metadata: TaskMetadata,
|
||||
task_id: str | None = None,
|
||||
) -> Task:
|
||||
"""
|
||||
Create a Task object with initial state.
|
||||
|
||||
This is a helper for TaskStore implementations.
|
||||
|
||||
Args:
|
||||
metadata: Task metadata
|
||||
task_id: Optional task ID (generated if not provided)
|
||||
|
||||
Returns:
|
||||
A new Task in "working" status
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
return Task(
|
||||
taskId=task_id or generate_task_id(),
|
||||
status=TASK_STATUS_WORKING,
|
||||
createdAt=now,
|
||||
lastUpdatedAt=now,
|
||||
ttl=metadata.ttl,
|
||||
pollInterval=500, # Default 500ms poll interval
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def task_execution(
|
||||
task_id: str,
|
||||
store: TaskStore,
|
||||
) -> AsyncIterator[TaskContext]:
|
||||
"""
|
||||
Context manager for safe task execution (pure, no server dependencies).
|
||||
|
||||
Loads a task from the store and provides a TaskContext for the work.
|
||||
If an unhandled exception occurs, the task is automatically marked as failed
|
||||
and the exception is suppressed (since the failure is captured in task state).
|
||||
|
||||
This is useful for distributed workers that don't have a server session.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier to execute
|
||||
store: The task store (must be accessible by the worker)
|
||||
|
||||
Yields:
|
||||
TaskContext for updating status and completing/failing the task
|
||||
|
||||
Raises:
|
||||
ValueError: If the task is not found in the store
|
||||
|
||||
Example (distributed worker):
|
||||
async def worker_process(task_id: str):
|
||||
store = RedisTaskStore(redis_url)
|
||||
async with task_execution(task_id, store) as ctx:
|
||||
await ctx.update_status("Working...")
|
||||
result = await do_work()
|
||||
await ctx.complete(result)
|
||||
"""
|
||||
task = await store.get_task(task_id)
|
||||
if task is None:
|
||||
raise ValueError(f"Task {task_id} not found")
|
||||
|
||||
ctx = TaskContext(task, store)
|
||||
try:
|
||||
yield ctx
|
||||
except Exception as e:
|
||||
# Auto-fail the task if an exception occurs and task isn't already terminal
|
||||
# Exception is suppressed since failure is captured in task state
|
||||
if not is_terminal(ctx.task.status):
|
||||
await ctx.fail(str(e))
|
||||
# Don't re-raise - the failure is recorded in task state
|
||||
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
In-memory implementation of TaskStore for demonstration purposes.
|
||||
|
||||
This implementation stores all tasks in memory and provides automatic cleanup
|
||||
based on the TTL duration specified in the task metadata using lazy expiration.
|
||||
|
||||
Note: This is not suitable for production use as all data is lost on restart.
|
||||
For production, consider implementing TaskStore with a database or distributed cache.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.shared.experimental.tasks.helpers import create_task_state, is_terminal
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
from mcp.types import Result, Task, TaskMetadata, TaskStatus
|
||||
|
||||
|
||||
@dataclass
|
||||
class StoredTask:
|
||||
"""Internal storage representation of a task."""
|
||||
|
||||
task: Task
|
||||
result: Result | None = None
|
||||
# Time when this task should be removed (None = never)
|
||||
expires_at: datetime | None = field(default=None)
|
||||
|
||||
|
||||
class InMemoryTaskStore(TaskStore):
|
||||
"""
|
||||
A simple in-memory implementation of TaskStore.
|
||||
|
||||
Features:
|
||||
- Automatic TTL-based cleanup (lazy expiration)
|
||||
- Thread-safe for single-process async use
|
||||
- Pagination support for list_tasks
|
||||
|
||||
Limitations:
|
||||
- All data lost on restart
|
||||
- Not suitable for distributed systems
|
||||
- No persistence
|
||||
|
||||
For production, implement TaskStore with Redis, PostgreSQL, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, page_size: int = 10) -> None:
|
||||
self._tasks: dict[str, StoredTask] = {}
|
||||
self._page_size = page_size
|
||||
self._update_events: dict[str, anyio.Event] = {}
|
||||
|
||||
def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
|
||||
"""Calculate expiry time from TTL in milliseconds."""
|
||||
if ttl_ms is None:
|
||||
return None
|
||||
return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms)
|
||||
|
||||
def _is_expired(self, stored: StoredTask) -> bool:
|
||||
"""Check if a task has expired."""
|
||||
if stored.expires_at is None:
|
||||
return False
|
||||
return datetime.now(timezone.utc) >= stored.expires_at
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Remove all expired tasks. Called lazily during access operations."""
|
||||
expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)]
|
||||
for task_id in expired_ids:
|
||||
del self._tasks[task_id]
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
metadata: TaskMetadata,
|
||||
task_id: str | None = None,
|
||||
) -> Task:
|
||||
"""Create a new task with the given metadata."""
|
||||
# Cleanup expired tasks on access
|
||||
self._cleanup_expired()
|
||||
|
||||
task = create_task_state(metadata, task_id)
|
||||
|
||||
if task.taskId in self._tasks:
|
||||
raise ValueError(f"Task with ID {task.taskId} already exists")
|
||||
|
||||
stored = StoredTask(
|
||||
task=task,
|
||||
expires_at=self._calculate_expiry(metadata.ttl),
|
||||
)
|
||||
self._tasks[task.taskId] = stored
|
||||
|
||||
# Return a copy to prevent external modification
|
||||
return Task(**task.model_dump())
|
||||
|
||||
async def get_task(self, task_id: str) -> Task | None:
|
||||
"""Get a task by ID."""
|
||||
# Cleanup expired tasks on access
|
||||
self._cleanup_expired()
|
||||
|
||||
stored = self._tasks.get(task_id)
|
||||
if stored is None:
|
||||
return None
|
||||
|
||||
# Return a copy to prevent external modification
|
||||
return Task(**stored.task.model_dump())
|
||||
|
||||
async def update_task(
|
||||
self,
|
||||
task_id: str,
|
||||
status: TaskStatus | None = None,
|
||||
status_message: str | None = None,
|
||||
) -> Task:
|
||||
"""Update a task's status and/or message."""
|
||||
stored = self._tasks.get(task_id)
|
||||
if stored is None:
|
||||
raise ValueError(f"Task with ID {task_id} not found")
|
||||
|
||||
# Per spec: Terminal states MUST NOT transition to any other status
|
||||
if status is not None and status != stored.task.status and is_terminal(stored.task.status):
|
||||
raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'")
|
||||
|
||||
status_changed = False
|
||||
if status is not None and stored.task.status != status:
|
||||
stored.task.status = status
|
||||
status_changed = True
|
||||
|
||||
if status_message is not None:
|
||||
stored.task.statusMessage = status_message
|
||||
|
||||
# Update lastUpdatedAt on any change
|
||||
stored.task.lastUpdatedAt = datetime.now(timezone.utc)
|
||||
|
||||
# If task is now terminal and has TTL, reset expiry timer
|
||||
if status is not None and is_terminal(status) and stored.task.ttl is not None:
|
||||
stored.expires_at = self._calculate_expiry(stored.task.ttl)
|
||||
|
||||
# Notify waiters if status changed
|
||||
if status_changed:
|
||||
await self.notify_update(task_id)
|
||||
|
||||
return Task(**stored.task.model_dump())
|
||||
|
||||
async def store_result(self, task_id: str, result: Result) -> None:
|
||||
"""Store the result for a task."""
|
||||
stored = self._tasks.get(task_id)
|
||||
if stored is None:
|
||||
raise ValueError(f"Task with ID {task_id} not found")
|
||||
|
||||
stored.result = result
|
||||
|
||||
async def get_result(self, task_id: str) -> Result | None:
|
||||
"""Get the stored result for a task."""
|
||||
stored = self._tasks.get(task_id)
|
||||
if stored is None:
|
||||
return None
|
||||
|
||||
return stored.result
|
||||
|
||||
async def list_tasks(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
) -> tuple[list[Task], str | None]:
|
||||
"""List tasks with pagination."""
|
||||
# Cleanup expired tasks on access
|
||||
self._cleanup_expired()
|
||||
|
||||
all_task_ids = list(self._tasks.keys())
|
||||
|
||||
start_index = 0
|
||||
if cursor is not None:
|
||||
try:
|
||||
cursor_index = all_task_ids.index(cursor)
|
||||
start_index = cursor_index + 1
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid cursor: {cursor}")
|
||||
|
||||
page_task_ids = all_task_ids[start_index : start_index + self._page_size]
|
||||
tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids]
|
||||
|
||||
# Determine next cursor
|
||||
next_cursor = None
|
||||
if start_index + self._page_size < len(all_task_ids) and page_task_ids:
|
||||
next_cursor = page_task_ids[-1]
|
||||
|
||||
return tasks, next_cursor
|
||||
|
||||
async def delete_task(self, task_id: str) -> bool:
|
||||
"""Delete a task."""
|
||||
if task_id not in self._tasks:
|
||||
return False
|
||||
|
||||
del self._tasks[task_id]
|
||||
return True
|
||||
|
||||
async def wait_for_update(self, task_id: str) -> None:
|
||||
"""Wait until the task status changes."""
|
||||
if task_id not in self._tasks:
|
||||
raise ValueError(f"Task with ID {task_id} not found")
|
||||
|
||||
# Create a fresh event for waiting (anyio.Event can't be cleared)
|
||||
self._update_events[task_id] = anyio.Event()
|
||||
event = self._update_events[task_id]
|
||||
await event.wait()
|
||||
|
||||
async def notify_update(self, task_id: str) -> None:
|
||||
"""Signal that a task has been updated."""
|
||||
if task_id in self._update_events:
|
||||
self._update_events[task_id].set()
|
||||
|
||||
# --- Testing/debugging helpers ---
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Cleanup all tasks (useful for testing or graceful shutdown)."""
|
||||
self._tasks.clear()
|
||||
self._update_events.clear()
|
||||
|
||||
def get_all_tasks(self) -> list[Task]:
|
||||
"""Get all tasks (useful for debugging). Returns copies to prevent modification."""
|
||||
self._cleanup_expired()
|
||||
return [Task(**stored.task.model_dump()) for stored in self._tasks.values()]
|
||||
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
TaskMessageQueue - FIFO queue for task-related messages.
|
||||
|
||||
This implements the core message queue pattern from the MCP Tasks spec.
|
||||
When a handler needs to send a request (like elicitation) during a task-augmented
|
||||
request, the message is enqueued instead of sent directly. Messages are delivered
|
||||
to the client only through the `tasks/result` endpoint.
|
||||
|
||||
This pattern enables:
|
||||
1. Decoupling request handling from message delivery
|
||||
2. Proper bidirectional communication via the tasks/result stream
|
||||
3. Automatic status management (working <-> input_required)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.shared.experimental.tasks.resolver import Resolver
|
||||
from mcp.types import JSONRPCNotification, JSONRPCRequest, RequestId
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueuedMessage:
|
||||
"""
|
||||
A message queued for delivery via tasks/result.
|
||||
|
||||
Messages are stored with their type and a resolver for requests
|
||||
that expect responses.
|
||||
"""
|
||||
|
||||
type: Literal["request", "notification"]
|
||||
"""Whether this is a request (expects response) or notification (one-way)."""
|
||||
|
||||
message: JSONRPCRequest | JSONRPCNotification
|
||||
"""The JSON-RPC message to send."""
|
||||
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
"""When the message was enqueued."""
|
||||
|
||||
resolver: Resolver[dict[str, Any]] | None = None
|
||||
"""Resolver to set when response arrives (only for requests)."""
|
||||
|
||||
original_request_id: RequestId | None = None
|
||||
"""The original request ID used internally, for routing responses back."""
|
||||
|
||||
|
||||
class TaskMessageQueue(ABC):
|
||||
"""
|
||||
Abstract interface for task message queuing.
|
||||
|
||||
This is a FIFO queue that stores messages to be delivered via `tasks/result`.
|
||||
When a task-augmented handler calls elicit() or sends a notification, the
|
||||
message is enqueued here instead of being sent directly to the client.
|
||||
|
||||
The `tasks/result` handler then dequeues and sends these messages through
|
||||
the transport, with `relatedRequestId` set to the tasks/result request ID
|
||||
so responses are routed correctly.
|
||||
|
||||
Implementations can use in-memory storage, Redis, etc.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def enqueue(self, task_id: str, message: QueuedMessage) -> None:
|
||||
"""
|
||||
Add a message to the queue for a task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
message: The message to enqueue
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def dequeue(self, task_id: str) -> QueuedMessage | None:
|
||||
"""
|
||||
Remove and return the next message from the queue.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
The next message, or None if queue is empty
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def peek(self, task_id: str) -> QueuedMessage | None:
|
||||
"""
|
||||
Return the next message without removing it.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
The next message, or None if queue is empty
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def is_empty(self, task_id: str) -> bool:
|
||||
"""
|
||||
Check if the queue is empty for a task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
True if no messages are queued
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self, task_id: str) -> list[QueuedMessage]:
|
||||
"""
|
||||
Remove and return all messages from the queue.
|
||||
|
||||
This is useful for cleanup when a task is cancelled or completed.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
All queued messages (may be empty)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def wait_for_message(self, task_id: str) -> None:
|
||||
"""
|
||||
Wait until a message is available in the queue.
|
||||
|
||||
This blocks until either:
|
||||
1. A message is enqueued for this task
|
||||
2. The wait is cancelled
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def notify_message_available(self, task_id: str) -> None:
|
||||
"""
|
||||
Signal that a message is available for a task.
|
||||
|
||||
This wakes up any coroutines waiting in wait_for_message().
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
"""
|
||||
|
||||
|
||||
class InMemoryTaskMessageQueue(TaskMessageQueue):
|
||||
"""
|
||||
In-memory implementation of TaskMessageQueue.
|
||||
|
||||
This is suitable for single-process servers. For distributed systems,
|
||||
implement TaskMessageQueue with Redis, RabbitMQ, etc.
|
||||
|
||||
Features:
|
||||
- FIFO ordering per task
|
||||
- Async wait for message availability
|
||||
- Thread-safe for single-process async use
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._queues: dict[str, list[QueuedMessage]] = {}
|
||||
self._events: dict[str, anyio.Event] = {}
|
||||
|
||||
def _get_queue(self, task_id: str) -> list[QueuedMessage]:
|
||||
"""Get or create the queue for a task."""
|
||||
if task_id not in self._queues:
|
||||
self._queues[task_id] = []
|
||||
return self._queues[task_id]
|
||||
|
||||
async def enqueue(self, task_id: str, message: QueuedMessage) -> None:
|
||||
"""Add a message to the queue."""
|
||||
queue = self._get_queue(task_id)
|
||||
queue.append(message)
|
||||
# Signal that a message is available
|
||||
await self.notify_message_available(task_id)
|
||||
|
||||
async def dequeue(self, task_id: str) -> QueuedMessage | None:
|
||||
"""Remove and return the next message."""
|
||||
queue = self._get_queue(task_id)
|
||||
if not queue:
|
||||
return None
|
||||
return queue.pop(0)
|
||||
|
||||
async def peek(self, task_id: str) -> QueuedMessage | None:
|
||||
"""Return the next message without removing it."""
|
||||
queue = self._get_queue(task_id)
|
||||
if not queue:
|
||||
return None
|
||||
return queue[0]
|
||||
|
||||
async def is_empty(self, task_id: str) -> bool:
|
||||
"""Check if the queue is empty."""
|
||||
queue = self._get_queue(task_id)
|
||||
return len(queue) == 0
|
||||
|
||||
async def clear(self, task_id: str) -> list[QueuedMessage]:
|
||||
"""Remove and return all messages."""
|
||||
queue = self._get_queue(task_id)
|
||||
messages = list(queue)
|
||||
queue.clear()
|
||||
return messages
|
||||
|
||||
async def wait_for_message(self, task_id: str) -> None:
|
||||
"""Wait until a message is available."""
|
||||
# Check if there are already messages
|
||||
if not await self.is_empty(task_id):
|
||||
return
|
||||
|
||||
# Create a fresh event for waiting (anyio.Event can't be cleared)
|
||||
self._events[task_id] = anyio.Event()
|
||||
event = self._events[task_id]
|
||||
|
||||
# Double-check after creating event (avoid race condition)
|
||||
if not await self.is_empty(task_id):
|
||||
return
|
||||
|
||||
# Wait for a new message
|
||||
await event.wait()
|
||||
|
||||
async def notify_message_available(self, task_id: str) -> None:
|
||||
"""Signal that a message is available."""
|
||||
if task_id in self._events:
|
||||
self._events[task_id].set()
|
||||
|
||||
def cleanup(self, task_id: str | None = None) -> None:
|
||||
"""
|
||||
Clean up queues and events.
|
||||
|
||||
Args:
|
||||
task_id: If provided, clean up only this task. Otherwise clean up all.
|
||||
"""
|
||||
if task_id is not None:
|
||||
self._queues.pop(task_id, None)
|
||||
self._events.pop(task_id, None)
|
||||
else:
|
||||
self._queues.clear()
|
||||
self._events.clear()
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Shared polling utilities for task operations.
|
||||
|
||||
This module provides generic polling logic that works for both client→server
|
||||
and server→client task polling.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.shared.experimental.tasks.helpers import is_terminal
|
||||
from mcp.types import GetTaskResult
|
||||
|
||||
|
||||
async def poll_until_terminal(
|
||||
get_task: Callable[[str], Awaitable[GetTaskResult]],
|
||||
task_id: str,
|
||||
default_interval_ms: int = 500,
|
||||
) -> AsyncIterator[GetTaskResult]:
|
||||
"""
|
||||
Poll a task until it reaches terminal status.
|
||||
|
||||
This is a generic utility that works for both client→server and server→client
|
||||
polling. The caller provides the get_task function appropriate for their direction.
|
||||
|
||||
Args:
|
||||
get_task: Async function that takes task_id and returns GetTaskResult
|
||||
task_id: The task to poll
|
||||
default_interval_ms: Fallback poll interval if server doesn't specify
|
||||
|
||||
Yields:
|
||||
GetTaskResult for each poll
|
||||
"""
|
||||
while True:
|
||||
status = await get_task(task_id)
|
||||
yield status
|
||||
|
||||
if is_terminal(status.status):
|
||||
break
|
||||
|
||||
interval_ms = status.pollInterval if status.pollInterval is not None else default_interval_ms
|
||||
await anyio.sleep(interval_ms / 1000)
|
||||
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Resolver - An anyio-compatible future-like object for async result passing.
|
||||
|
||||
This provides a simple way to pass a result (or exception) from one coroutine
|
||||
to another without depending on asyncio.Future.
|
||||
"""
|
||||
|
||||
from typing import Generic, TypeVar, cast
|
||||
|
||||
import anyio
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Resolver(Generic[T]):
|
||||
"""
|
||||
A simple resolver for passing results between coroutines.
|
||||
|
||||
Unlike asyncio.Future, this works with any anyio-compatible async backend.
|
||||
|
||||
Usage:
|
||||
resolver: Resolver[str] = Resolver()
|
||||
|
||||
# In one coroutine:
|
||||
resolver.set_result("hello")
|
||||
|
||||
# In another coroutine:
|
||||
result = await resolver.wait() # returns "hello"
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._event = anyio.Event()
|
||||
self._value: T | None = None
|
||||
self._exception: BaseException | None = None
|
||||
|
||||
def set_result(self, value: T) -> None:
|
||||
"""Set the result value and wake up waiters."""
|
||||
if self._event.is_set():
|
||||
raise RuntimeError("Resolver already completed")
|
||||
self._value = value
|
||||
self._event.set()
|
||||
|
||||
def set_exception(self, exc: BaseException) -> None:
|
||||
"""Set an exception and wake up waiters."""
|
||||
if self._event.is_set():
|
||||
raise RuntimeError("Resolver already completed")
|
||||
self._exception = exc
|
||||
self._event.set()
|
||||
|
||||
async def wait(self) -> T:
|
||||
"""Wait for the result and return it, or raise the exception."""
|
||||
await self._event.wait()
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
# If we reach here, set_result() was called, so _value is set
|
||||
return cast(T, self._value)
|
||||
|
||||
def done(self) -> bool:
|
||||
"""Return True if the resolver has been completed."""
|
||||
return self._event.is_set()
|
||||
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
TaskStore - Abstract interface for task state storage.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from mcp.types import Result, Task, TaskMetadata, TaskStatus
|
||||
|
||||
|
||||
class TaskStore(ABC):
|
||||
"""
|
||||
Abstract interface for task state storage.
|
||||
|
||||
This is a pure storage interface - it doesn't manage execution.
|
||||
Implementations can use in-memory storage, databases, Redis, etc.
|
||||
|
||||
All methods are async to support various backends.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def create_task(
|
||||
self,
|
||||
metadata: TaskMetadata,
|
||||
task_id: str | None = None,
|
||||
) -> Task:
|
||||
"""
|
||||
Create a new task.
|
||||
|
||||
Args:
|
||||
metadata: Task metadata (ttl, etc.)
|
||||
task_id: Optional task ID. If None, implementation should generate one.
|
||||
|
||||
Returns:
|
||||
The created Task with status="working"
|
||||
|
||||
Raises:
|
||||
ValueError: If task_id already exists
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_task(self, task_id: str) -> Task | None:
|
||||
"""
|
||||
Get a task by ID.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
The Task, or None if not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def update_task(
|
||||
self,
|
||||
task_id: str,
|
||||
status: TaskStatus | None = None,
|
||||
status_message: str | None = None,
|
||||
) -> Task:
|
||||
"""
|
||||
Update a task's status and/or message.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
status: New status (if changing)
|
||||
status_message: New status message (if changing)
|
||||
|
||||
Returns:
|
||||
The updated Task
|
||||
|
||||
Raises:
|
||||
ValueError: If task not found
|
||||
ValueError: If attempting to transition from a terminal status
|
||||
(completed, failed, cancelled). Per spec, terminal states
|
||||
MUST NOT transition to any other status.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def store_result(self, task_id: str, result: Result) -> None:
|
||||
"""
|
||||
Store the result for a task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
result: The result to store
|
||||
|
||||
Raises:
|
||||
ValueError: If task not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_result(self, task_id: str) -> Result | None:
|
||||
"""
|
||||
Get the stored result for a task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
The stored Result, or None if not available
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def list_tasks(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
) -> tuple[list[Task], str | None]:
|
||||
"""
|
||||
List tasks with pagination.
|
||||
|
||||
Args:
|
||||
cursor: Optional cursor for pagination
|
||||
|
||||
Returns:
|
||||
Tuple of (tasks, next_cursor). next_cursor is None if no more pages.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_task(self, task_id: str) -> bool:
|
||||
"""
|
||||
Delete a task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def wait_for_update(self, task_id: str) -> None:
|
||||
"""
|
||||
Wait until the task status changes.
|
||||
|
||||
This blocks until either:
|
||||
1. The task status changes
|
||||
2. The wait is cancelled
|
||||
|
||||
Used by tasks/result to wait for task completion or status changes.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Raises:
|
||||
ValueError: If task not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def notify_update(self, task_id: str) -> None:
|
||||
"""
|
||||
Signal that a task has been updated.
|
||||
|
||||
This wakes up any coroutines waiting in wait_for_update().
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
"""
|
||||
98
.venv/lib/python3.11/site-packages/mcp/shared/memory.py
Normal file
98
.venv/lib/python3.11/site-packages/mcp/shared/memory.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
In-memory transports
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
|
||||
from mcp.server import Server
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
MessageStream = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageStream, MessageStream], None]:
|
||||
"""
|
||||
Creates a pair of bidirectional memory streams for client-server communication.
|
||||
|
||||
Returns:
|
||||
A tuple of (client_streams, server_streams) where each is a tuple of
|
||||
(read_stream, write_stream)
|
||||
"""
|
||||
# Create streams for both directions
|
||||
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
|
||||
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
|
||||
|
||||
client_streams = (server_to_client_receive, client_to_server_send)
|
||||
server_streams = (client_to_server_receive, server_to_client_send)
|
||||
|
||||
async with (
|
||||
server_to_client_receive,
|
||||
client_to_server_send,
|
||||
client_to_server_receive,
|
||||
server_to_client_send,
|
||||
):
|
||||
yield client_streams, server_streams
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_connected_server_and_client_session(
|
||||
server: Server[Any] | FastMCP,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
logging_callback: LoggingFnT | None = None,
|
||||
message_handler: MessageHandlerFnT | None = None,
|
||||
client_info: types.Implementation | None = None,
|
||||
raise_exceptions: bool = False,
|
||||
elicitation_callback: ElicitationFnT | None = None,
|
||||
) -> AsyncGenerator[ClientSession, None]:
|
||||
"""Creates a ClientSession that is connected to a running MCP server."""
|
||||
|
||||
# TODO(Marcelo): we should have a proper `Client` that can use this "in-memory transport",
|
||||
# and we should expose a method in the `FastMCP` so we don't access a private attribute.
|
||||
if isinstance(server, FastMCP): # pragma: no cover
|
||||
server = server._mcp_server # type: ignore[reportPrivateUsage]
|
||||
|
||||
async with create_client_server_memory_streams() as (client_streams, server_streams):
|
||||
client_read, client_write = client_streams
|
||||
server_read, server_write = server_streams
|
||||
|
||||
# Create a cancel scope for the server task
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(
|
||||
lambda: server.run(
|
||||
server_read,
|
||||
server_write,
|
||||
server.create_initialization_options(),
|
||||
raise_exceptions=raise_exceptions,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
async with ClientSession(
|
||||
read_stream=client_read,
|
||||
write_stream=client_write,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
sampling_callback=sampling_callback,
|
||||
list_roots_callback=list_roots_callback,
|
||||
logging_callback=logging_callback,
|
||||
message_handler=message_handler,
|
||||
client_info=client_info,
|
||||
elicitation_callback=elicitation_callback,
|
||||
) as client_session:
|
||||
await client_session.initialize()
|
||||
yield client_session
|
||||
finally: # pragma: no cover
|
||||
tg.cancel_scope.cancel()
|
||||
50
.venv/lib/python3.11/site-packages/mcp/shared/message.py
Normal file
50
.venv/lib/python3.11/site-packages/mcp/shared/message.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Message wrapper with metadata support.
|
||||
|
||||
This module defines a wrapper type that combines JSONRPCMessage with metadata
|
||||
to support transport-specific features like resumability.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mcp.types import JSONRPCMessage, RequestId
|
||||
|
||||
ResumptionToken = str
|
||||
|
||||
ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]]
|
||||
|
||||
# Callback type for closing SSE streams without terminating
|
||||
CloseSSEStreamCallback = Callable[[], Awaitable[None]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientMessageMetadata:
|
||||
"""Metadata specific to client messages."""
|
||||
|
||||
resumption_token: ResumptionToken | None = None
|
||||
on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerMessageMetadata:
|
||||
"""Metadata specific to server messages."""
|
||||
|
||||
related_request_id: RequestId | None = None
|
||||
# Request-specific context (e.g., headers, auth info)
|
||||
request_context: object | None = None
|
||||
# Callback to close SSE stream for the current request without terminating
|
||||
close_sse_stream: CloseSSEStreamCallback | None = None
|
||||
# Callback to close the standalone GET SSE stream (for unsolicited notifications)
|
||||
close_standalone_sse_stream: CloseSSEStreamCallback | None = None
|
||||
|
||||
|
||||
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionMessage:
|
||||
"""A message with specific metadata for transport-specific features."""
|
||||
|
||||
message: JSONRPCMessage
|
||||
metadata: MessageMetadata = None
|
||||
@@ -0,0 +1,45 @@
|
||||
"""Utility functions for working with metadata in MCP types.
|
||||
|
||||
These utilities are primarily intended for client-side usage to properly display
|
||||
human-readable names in user interfaces in a spec compliant way.
|
||||
"""
|
||||
|
||||
from mcp.types import Implementation, Prompt, Resource, ResourceTemplate, Tool
|
||||
|
||||
|
||||
def get_display_name(obj: Tool | Resource | Prompt | ResourceTemplate | Implementation) -> str:
|
||||
"""
|
||||
Get the display name for an MCP object with proper precedence.
|
||||
|
||||
This is a client-side utility function designed to help MCP clients display
|
||||
human-readable names in their user interfaces. When servers provide a 'title'
|
||||
field, it should be preferred over the programmatic 'name' field for display.
|
||||
|
||||
For tools: title > annotations.title > name
|
||||
For other objects: title > name
|
||||
|
||||
Example:
|
||||
# In a client displaying available tools
|
||||
tools = await session.list_tools()
|
||||
for tool in tools.tools:
|
||||
display_name = get_display_name(tool)
|
||||
print(f"Available tool: {display_name}")
|
||||
|
||||
Args:
|
||||
obj: An MCP object with name and optional title fields
|
||||
|
||||
Returns:
|
||||
The display name to use for UI presentation
|
||||
"""
|
||||
if isinstance(obj, Tool):
|
||||
# Tools have special precedence: title > annotations.title > name
|
||||
if hasattr(obj, "title") and obj.title is not None:
|
||||
return obj.title
|
||||
if obj.annotations and hasattr(obj.annotations, "title") and obj.annotations.title is not None:
|
||||
return obj.annotations.title
|
||||
return obj.name
|
||||
else:
|
||||
# All other objects: title > name
|
||||
if hasattr(obj, "title") and obj.title is not None:
|
||||
return obj.title
|
||||
return obj.name
|
||||
58
.venv/lib/python3.11/site-packages/mcp/shared/progress.py
Normal file
58
.venv/lib/python3.11/site-packages/mcp/shared/progress.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mcp.shared.context import LifespanContextT, RequestContext
|
||||
from mcp.shared.session import (
|
||||
BaseSession,
|
||||
ReceiveNotificationT,
|
||||
ReceiveRequestT,
|
||||
SendNotificationT,
|
||||
SendRequestT,
|
||||
SendResultT,
|
||||
)
|
||||
from mcp.types import ProgressToken
|
||||
|
||||
|
||||
class Progress(BaseModel):
|
||||
progress: float
|
||||
total: float | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProgressContext(Generic[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]):
|
||||
session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]
|
||||
progress_token: ProgressToken
|
||||
total: float | None
|
||||
current: float = field(default=0.0, init=False)
|
||||
|
||||
async def progress(self, amount: float, message: str | None = None) -> None:
|
||||
self.current += amount
|
||||
|
||||
await self.session.send_progress_notification(
|
||||
self.progress_token, self.current, total=self.total, message=message
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def progress(
|
||||
ctx: RequestContext[
|
||||
BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
|
||||
LifespanContextT,
|
||||
],
|
||||
total: float | None = None,
|
||||
) -> Generator[
|
||||
ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
|
||||
None,
|
||||
]:
|
||||
if ctx.meta is None or ctx.meta.progressToken is None: # pragma: no cover
|
||||
raise ValueError("No progress token provided")
|
||||
|
||||
progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total)
|
||||
try:
|
||||
yield progress_ctx
|
||||
finally:
|
||||
pass
|
||||
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
ResponseRouter - Protocol for pluggable response routing.
|
||||
|
||||
This module defines a protocol for routing JSON-RPC responses to alternative
|
||||
handlers before falling back to the default response stream mechanism.
|
||||
|
||||
The primary use case is task-augmented requests: when a TaskSession enqueues
|
||||
a request (like elicitation), the response needs to be routed back to the
|
||||
waiting resolver instead of the normal response stream.
|
||||
|
||||
Design:
|
||||
- Protocol-based for testability and flexibility
|
||||
- Returns bool to indicate if response was handled
|
||||
- Supports both success responses and errors
|
||||
"""
|
||||
|
||||
from typing import Any, Protocol
|
||||
|
||||
from mcp.types import ErrorData, RequestId
|
||||
|
||||
|
||||
class ResponseRouter(Protocol):
|
||||
"""
|
||||
Protocol for routing responses to alternative handlers.
|
||||
|
||||
Implementations check if they have a pending request for the given ID
|
||||
and deliver the response/error to the appropriate handler.
|
||||
|
||||
Example:
|
||||
class TaskResultHandler(ResponseRouter):
|
||||
def route_response(self, request_id, response):
|
||||
resolver = self._pending_requests.pop(request_id, None)
|
||||
if resolver:
|
||||
resolver.set_result(response)
|
||||
return True
|
||||
return False
|
||||
"""
|
||||
|
||||
def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Try to route a response to a pending request handler.
|
||||
|
||||
Args:
|
||||
request_id: The JSON-RPC request ID from the response
|
||||
response: The response result data
|
||||
|
||||
Returns:
|
||||
True if the response was handled, False otherwise
|
||||
"""
|
||||
... # pragma: no cover
|
||||
|
||||
def route_error(self, request_id: RequestId, error: ErrorData) -> bool:
|
||||
"""
|
||||
Try to route an error to a pending request handler.
|
||||
|
||||
Args:
|
||||
request_id: The JSON-RPC request ID from the error response
|
||||
error: The error data
|
||||
|
||||
Returns:
|
||||
True if the error was handled, False otherwise
|
||||
"""
|
||||
... # pragma: no cover
|
||||
550
.venv/lib/python3.11/site-packages/mcp/shared/session.py
Normal file
550
.venv/lib/python3.11/site-packages/mcp/shared/session.py
Normal file
@@ -0,0 +1,550 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from contextlib import AsyncExitStack
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, Generic, Protocol, TypeVar
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
|
||||
from mcp.shared.response_router import ResponseRouter
|
||||
from mcp.types import (
|
||||
CONNECTION_CLOSED,
|
||||
INVALID_PARAMS,
|
||||
CancelledNotification,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
ErrorData,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
ProgressNotification,
|
||||
RequestParams,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
)
|
||||
|
||||
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
|
||||
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
|
||||
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
|
||||
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
|
||||
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
|
||||
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
|
||||
|
||||
RequestId = str | int
|
||||
|
||||
|
||||
class ProgressFnT(Protocol):
|
||||
"""Protocol for progress notification callbacks."""
|
||||
|
||||
async def __call__(
|
||||
self, progress: float, total: float | None, message: str | None
|
||||
) -> None: ... # pragma: no branch
|
||||
|
||||
|
||||
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
|
||||
"""Handles responding to MCP requests and manages request lifecycle.
|
||||
|
||||
This class MUST be used as a context manager to ensure proper cleanup and
|
||||
cancellation handling:
|
||||
|
||||
Example:
|
||||
with request_responder as resp:
|
||||
await resp.respond(result)
|
||||
|
||||
The context manager ensures:
|
||||
1. Proper cancellation scope setup and cleanup
|
||||
2. Request completion tracking
|
||||
3. Cleanup of in-flight requests
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: RequestId,
|
||||
request_meta: RequestParams.Meta | None,
|
||||
request: ReceiveRequestT,
|
||||
session: """BaseSession[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT
|
||||
]""",
|
||||
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
|
||||
message_metadata: MessageMetadata = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.request_meta = request_meta
|
||||
self.request = request
|
||||
self.message_metadata = message_metadata
|
||||
self._session = session
|
||||
self._completed = False
|
||||
self._cancel_scope = anyio.CancelScope()
|
||||
self._on_complete = on_complete
|
||||
self._entered = False # Track if we're in a context manager
|
||||
|
||||
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
|
||||
"""Enter the context manager, enabling request cancellation tracking."""
|
||||
self._entered = True
|
||||
self._cancel_scope = anyio.CancelScope()
|
||||
self._cancel_scope.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit the context manager, performing cleanup and notifying completion."""
|
||||
try:
|
||||
if self._completed: # pragma: no branch
|
||||
self._on_complete(self)
|
||||
finally:
|
||||
self._entered = False
|
||||
if not self._cancel_scope: # pragma: no cover
|
||||
raise RuntimeError("No active cancel scope")
|
||||
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
async def respond(self, response: SendResultT | ErrorData) -> None:
|
||||
"""Send a response for this request.
|
||||
|
||||
Must be called within a context manager block.
|
||||
Raises:
|
||||
RuntimeError: If not used within a context manager
|
||||
AssertionError: If request was already responded to
|
||||
"""
|
||||
if not self._entered: # pragma: no cover
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
assert not self._completed, "Request already responded to"
|
||||
|
||||
if not self.cancelled: # pragma: no branch
|
||||
self._completed = True
|
||||
|
||||
await self._session._send_response( # type: ignore[reportPrivateUsage]
|
||||
request_id=self.request_id, response=response
|
||||
)
|
||||
|
||||
async def cancel(self) -> None:
|
||||
"""Cancel this request and mark it as completed."""
|
||||
if not self._entered: # pragma: no cover
|
||||
raise RuntimeError("RequestResponder must be used as a context manager")
|
||||
if not self._cancel_scope: # pragma: no cover
|
||||
raise RuntimeError("No active cancel scope")
|
||||
|
||||
self._cancel_scope.cancel()
|
||||
self._completed = True # Mark as completed so it's removed from in_flight
|
||||
# Send an error response to indicate cancellation
|
||||
await self._session._send_response( # type: ignore[reportPrivateUsage]
|
||||
request_id=self.request_id,
|
||||
response=ErrorData(code=0, message="Request cancelled", data=None),
|
||||
)
|
||||
|
||||
@property
|
||||
def in_flight(self) -> bool: # pragma: no cover
|
||||
return not self._completed and not self.cancelled
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool: # pragma: no cover
|
||||
return self._cancel_scope.cancel_called
|
||||
|
||||
|
||||
class BaseSession(
|
||||
Generic[
|
||||
SendRequestT,
|
||||
SendNotificationT,
|
||||
SendResultT,
|
||||
ReceiveRequestT,
|
||||
ReceiveNotificationT,
|
||||
],
|
||||
):
|
||||
"""
|
||||
Implements an MCP "session" on top of read/write streams, including features
|
||||
like request/response linking, notifications, and progress.
|
||||
|
||||
This class is an async context manager that automatically starts processing
|
||||
messages when entered.
|
||||
"""
|
||||
|
||||
_response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]]
|
||||
_request_id: int
|
||||
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
|
||||
_progress_callbacks: dict[RequestId, ProgressFnT]
|
||||
_response_routers: list["ResponseRouter"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||
receive_request_type: type[ReceiveRequestT],
|
||||
receive_notification_type: type[ReceiveNotificationT],
|
||||
# If none, reading will never time out
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
) -> None:
|
||||
self._read_stream = read_stream
|
||||
self._write_stream = write_stream
|
||||
self._response_streams = {}
|
||||
self._request_id = 0
|
||||
self._receive_request_type = receive_request_type
|
||||
self._receive_notification_type = receive_notification_type
|
||||
self._session_read_timeout_seconds = read_timeout_seconds
|
||||
self._in_flight = {}
|
||||
self._progress_callbacks = {}
|
||||
self._response_routers = []
|
||||
self._exit_stack = AsyncExitStack()
|
||||
|
||||
def add_response_router(self, router: ResponseRouter) -> None:
|
||||
"""
|
||||
Register a response router to handle responses for non-standard requests.
|
||||
|
||||
Response routers are checked in order before falling back to the default
|
||||
response stream mechanism. This is used by TaskResultHandler to route
|
||||
responses for queued task requests back to their resolvers.
|
||||
|
||||
WARNING: This is an experimental API that may change without notice.
|
||||
|
||||
Args:
|
||||
router: A ResponseRouter implementation
|
||||
"""
|
||||
self._response_routers.append(router)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
self._task_group = anyio.create_task_group()
|
||||
await self._task_group.__aenter__()
|
||||
self._task_group.start_soon(self._receive_loop)
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> bool | None:
|
||||
await self._exit_stack.aclose()
|
||||
# Using BaseSession as a context manager should not block on exit (this
|
||||
# would be very surprising behavior), so make sure to cancel the tasks
|
||||
# in the task group.
|
||||
self._task_group.cancel_scope.cancel()
|
||||
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
async def send_request(
|
||||
self,
|
||||
request: SendRequestT,
|
||||
result_type: type[ReceiveResultT],
|
||||
request_read_timeout_seconds: timedelta | None = None,
|
||||
metadata: MessageMetadata = None,
|
||||
progress_callback: ProgressFnT | None = None,
|
||||
) -> ReceiveResultT:
|
||||
"""
|
||||
Sends a request and wait for a response. Raises an McpError if the
|
||||
response contains an error. If a request read timeout is provided, it
|
||||
will take precedence over the session read timeout.
|
||||
|
||||
Do not use this method to emit notifications! Use send_notification()
|
||||
instead.
|
||||
"""
|
||||
request_id = self._request_id
|
||||
self._request_id = request_id + 1
|
||||
|
||||
response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1)
|
||||
self._response_streams[request_id] = response_stream
|
||||
|
||||
# Set up progress token if progress callback is provided
|
||||
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
if progress_callback is not None: # pragma: no cover
|
||||
# Use request_id as progress token
|
||||
if "params" not in request_data:
|
||||
request_data["params"] = {}
|
||||
if "_meta" not in request_data["params"]: # pragma: no branch
|
||||
request_data["params"]["_meta"] = {}
|
||||
request_data["params"]["_meta"]["progressToken"] = request_id
|
||||
# Store the callback for this request
|
||||
self._progress_callbacks[request_id] = progress_callback
|
||||
|
||||
try:
|
||||
jsonrpc_request = JSONRPCRequest(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
**request_data,
|
||||
)
|
||||
|
||||
await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
|
||||
|
||||
# request read timeout takes precedence over session read timeout
|
||||
timeout = None
|
||||
if request_read_timeout_seconds is not None: # pragma: no cover
|
||||
timeout = request_read_timeout_seconds.total_seconds()
|
||||
elif self._session_read_timeout_seconds is not None: # pragma: no cover
|
||||
timeout = self._session_read_timeout_seconds.total_seconds()
|
||||
|
||||
try:
|
||||
with anyio.fail_after(timeout):
|
||||
response_or_error = await response_stream_reader.receive()
|
||||
except TimeoutError:
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=httpx.codes.REQUEST_TIMEOUT,
|
||||
message=(
|
||||
f"Timed out while waiting for response to "
|
||||
f"{request.__class__.__name__}. Waited "
|
||||
f"{timeout} seconds."
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(response_or_error, JSONRPCError):
|
||||
raise McpError(response_or_error.error)
|
||||
else:
|
||||
return result_type.model_validate(response_or_error.result)
|
||||
|
||||
finally:
|
||||
self._response_streams.pop(request_id, None)
|
||||
self._progress_callbacks.pop(request_id, None)
|
||||
await response_stream.aclose()
|
||||
await response_stream_reader.aclose()
|
||||
|
||||
async def send_notification(
|
||||
self,
|
||||
notification: SendNotificationT,
|
||||
related_request_id: RequestId | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Emits a notification, which is a one-way message that does not expect
|
||||
a response.
|
||||
"""
|
||||
# Some transport implementations may need to set the related_request_id
|
||||
# to attribute to the notifications to the request that triggered them.
|
||||
jsonrpc_notification = JSONRPCNotification(
|
||||
jsonrpc="2.0",
|
||||
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
session_message = SessionMessage( # pragma: no cover
|
||||
message=JSONRPCMessage(jsonrpc_notification),
|
||||
metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
|
||||
)
|
||||
await self._write_stream.send(session_message)
|
||||
|
||||
async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
|
||||
if isinstance(response, ErrorData):
|
||||
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
|
||||
await self._write_stream.send(session_message)
|
||||
else:
|
||||
jsonrpc_response = JSONRPCResponse(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
|
||||
await self._write_stream.send(session_message)
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
async with (
|
||||
self._read_stream,
|
||||
self._write_stream,
|
||||
):
|
||||
try:
|
||||
async for message in self._read_stream:
|
||||
if isinstance(message, Exception): # pragma: no cover
|
||||
await self._handle_incoming(message)
|
||||
elif isinstance(message.message.root, JSONRPCRequest):
|
||||
try:
|
||||
validated_request = self._receive_request_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
responder = RequestResponder(
|
||||
request_id=message.message.root.id,
|
||||
request_meta=validated_request.root.params.meta
|
||||
if validated_request.root.params
|
||||
else None,
|
||||
request=validated_request,
|
||||
session=self,
|
||||
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
|
||||
message_metadata=message.metadata,
|
||||
)
|
||||
self._in_flight[responder.request_id] = responder
|
||||
await self._received_request(responder)
|
||||
|
||||
if not responder._completed: # type: ignore[reportPrivateUsage]
|
||||
await self._handle_incoming(responder)
|
||||
except Exception as e:
|
||||
# For request validation errors, send a proper JSON-RPC error
|
||||
# response instead of crashing the server
|
||||
logging.warning(f"Failed to validate request: {e}")
|
||||
logging.debug(f"Message that failed validation: {message.message.root}")
|
||||
error_response = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=message.message.root.id,
|
||||
error=ErrorData(
|
||||
code=INVALID_PARAMS,
|
||||
message="Invalid request parameters",
|
||||
data="",
|
||||
),
|
||||
)
|
||||
session_message = SessionMessage(message=JSONRPCMessage(error_response))
|
||||
await self._write_stream.send(session_message)
|
||||
|
||||
elif isinstance(message.message.root, JSONRPCNotification):
|
||||
try:
|
||||
notification = self._receive_notification_type.model_validate(
|
||||
message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
)
|
||||
# Handle cancellation notifications
|
||||
if isinstance(notification.root, CancelledNotification):
|
||||
cancelled_id = notification.root.params.requestId
|
||||
if cancelled_id in self._in_flight: # pragma: no branch
|
||||
await self._in_flight[cancelled_id].cancel()
|
||||
else:
|
||||
# Handle progress notifications callback
|
||||
if isinstance(notification.root, ProgressNotification): # pragma: no cover
|
||||
progress_token = notification.root.params.progressToken
|
||||
# If there is a progress callback for this token,
|
||||
# call it with the progress information
|
||||
if progress_token in self._progress_callbacks:
|
||||
callback = self._progress_callbacks[progress_token]
|
||||
try:
|
||||
await callback(
|
||||
notification.root.params.progress,
|
||||
notification.root.params.total,
|
||||
notification.root.params.message,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
"Progress callback raised an exception: %s",
|
||||
e,
|
||||
)
|
||||
await self._received_notification(notification)
|
||||
await self._handle_incoming(notification)
|
||||
except Exception as e: # pragma: no cover
|
||||
# For other validation errors, log and continue
|
||||
logging.warning(
|
||||
f"Failed to validate notification: {e}. Message was: {message.message.root}"
|
||||
)
|
||||
else: # Response or error
|
||||
await self._handle_response(message)
|
||||
|
||||
except anyio.ClosedResourceError:
|
||||
# This is expected when the client disconnects abruptly.
|
||||
# Without this handler, the exception would propagate up and
|
||||
# crash the server's task group.
|
||||
logging.debug("Read stream closed by client") # pragma: no cover
|
||||
except Exception as e: # pragma: no cover
|
||||
# Other exceptions are not expected and should be logged. We purposefully
|
||||
# catch all exceptions here to avoid crashing the server.
|
||||
logging.exception(f"Unhandled exception in receive loop: {e}")
|
||||
finally:
|
||||
# after the read stream is closed, we need to send errors
|
||||
# to any pending requests
|
||||
for id, stream in self._response_streams.items():
|
||||
error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
|
||||
try:
|
||||
await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
|
||||
await stream.aclose()
|
||||
except Exception: # pragma: no cover
|
||||
# Stream might already be closed
|
||||
pass
|
||||
self._response_streams.clear()
|
||||
|
||||
def _normalize_request_id(self, response_id: RequestId) -> RequestId:
|
||||
"""
|
||||
Normalize a response ID to match how request IDs are stored.
|
||||
|
||||
Since the client always sends integer IDs, we normalize string IDs
|
||||
to integers when possible. This matches the TypeScript SDK approach:
|
||||
https://github.com/modelcontextprotocol/typescript-sdk/blob/a606fb17909ea454e83aab14c73f14ea45c04448/src/shared/protocol.ts#L861
|
||||
|
||||
Args:
|
||||
response_id: The response ID from the incoming message.
|
||||
|
||||
Returns:
|
||||
The normalized ID (int if possible, otherwise original value).
|
||||
"""
|
||||
if isinstance(response_id, str):
|
||||
try:
|
||||
return int(response_id)
|
||||
except ValueError:
|
||||
logging.warning(f"Response ID {response_id!r} cannot be normalized to match pending requests")
|
||||
return response_id
|
||||
|
||||
async def _handle_response(self, message: SessionMessage) -> None:
|
||||
"""
|
||||
Handle an incoming response or error message.
|
||||
|
||||
Checks response routers first (e.g., for task-related responses),
|
||||
then falls back to the normal response stream mechanism.
|
||||
"""
|
||||
root = message.message.root
|
||||
|
||||
# This check is always true at runtime: the caller (_receive_loop) only invokes
|
||||
# this method in the else branch after checking for JSONRPCRequest and
|
||||
# JSONRPCNotification. However, the type checker can't infer this from the
|
||||
# method signature, so we need this guard for type narrowing.
|
||||
if not isinstance(root, JSONRPCResponse | JSONRPCError):
|
||||
return # pragma: no cover
|
||||
|
||||
# Normalize response ID to handle type mismatches (e.g., "0" vs 0)
|
||||
response_id = self._normalize_request_id(root.id)
|
||||
|
||||
# First, check response routers (e.g., TaskResultHandler)
|
||||
if isinstance(root, JSONRPCError):
|
||||
# Route error to routers
|
||||
for router in self._response_routers:
|
||||
if router.route_error(response_id, root.error):
|
||||
return # Handled
|
||||
else:
|
||||
# Route success response to routers
|
||||
response_data: dict[str, Any] = root.result or {}
|
||||
for router in self._response_routers:
|
||||
if router.route_response(response_id, response_data):
|
||||
return # Handled
|
||||
|
||||
# Fall back to normal response streams
|
||||
stream = self._response_streams.pop(response_id, None)
|
||||
if stream: # pragma: no cover
|
||||
await stream.send(root)
|
||||
else: # pragma: no cover
|
||||
await self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
|
||||
|
||||
async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a request without needing to
|
||||
listen on the message stream.
|
||||
|
||||
If the request is responded to within this method, it will not be
|
||||
forwarded on to the message stream.
|
||||
"""
|
||||
|
||||
async def _received_notification(self, notification: ReceiveNotificationT) -> None:
|
||||
"""
|
||||
Can be overridden by subclasses to handle a notification without needing
|
||||
to listen on the message stream.
|
||||
"""
|
||||
|
||||
async def send_progress_notification(
|
||||
self,
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Sends a progress notification for a request that is currently being
|
||||
processed.
|
||||
"""
|
||||
|
||||
async def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
|
||||
) -> None:
|
||||
"""A generic handler for incoming messages. Overwritten by subclasses."""
|
||||
pass # pragma: no cover
|
||||
@@ -0,0 +1,129 @@
|
||||
"""Tool name validation utilities according to SEP-986.
|
||||
|
||||
Tool names SHOULD be between 1 and 128 characters in length (inclusive).
|
||||
Tool names are case-sensitive.
|
||||
Allowed characters: uppercase and lowercase ASCII letters (A-Z, a-z),
|
||||
digits (0-9), underscore (_), dash (-), and dot (.).
|
||||
Tool names SHOULD NOT contain spaces, commas, or other special characters.
|
||||
|
||||
See: https://modelcontextprotocol.io/specification/2025-11-25/server/tools#tool-names
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Regular expression for valid tool names according to SEP-986 specification
|
||||
TOOL_NAME_REGEX = re.compile(r"^[A-Za-z0-9._-]{1,128}$")
|
||||
|
||||
# SEP reference URL for warning messages
|
||||
SEP_986_URL = "https://modelcontextprotocol.io/specification/2025-11-25/server/tools#tool-names"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolNameValidationResult:
|
||||
"""Result of tool name validation.
|
||||
|
||||
Attributes:
|
||||
is_valid: Whether the tool name conforms to SEP-986 requirements.
|
||||
warnings: List of warning messages for non-conforming aspects.
|
||||
"""
|
||||
|
||||
is_valid: bool
|
||||
warnings: list[str] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
def validate_tool_name(name: str) -> ToolNameValidationResult:
|
||||
"""Validate a tool name according to the SEP-986 specification.
|
||||
|
||||
Args:
|
||||
name: The tool name to validate.
|
||||
|
||||
Returns:
|
||||
ToolNameValidationResult containing validation status and any warnings.
|
||||
"""
|
||||
warnings: list[str] = []
|
||||
|
||||
# Check for empty name
|
||||
if not name:
|
||||
return ToolNameValidationResult(
|
||||
is_valid=False,
|
||||
warnings=["Tool name cannot be empty"],
|
||||
)
|
||||
|
||||
# Check length
|
||||
if len(name) > 128:
|
||||
return ToolNameValidationResult(
|
||||
is_valid=False,
|
||||
warnings=[f"Tool name exceeds maximum length of 128 characters (current: {len(name)})"],
|
||||
)
|
||||
|
||||
# Check for problematic patterns (warnings, not validation failures)
|
||||
if " " in name:
|
||||
warnings.append("Tool name contains spaces, which may cause parsing issues")
|
||||
|
||||
if "," in name:
|
||||
warnings.append("Tool name contains commas, which may cause parsing issues")
|
||||
|
||||
# Check for potentially confusing leading/trailing characters
|
||||
if name.startswith("-") or name.endswith("-"):
|
||||
warnings.append("Tool name starts or ends with a dash, which may cause parsing issues in some contexts")
|
||||
|
||||
if name.startswith(".") or name.endswith("."):
|
||||
warnings.append("Tool name starts or ends with a dot, which may cause parsing issues in some contexts")
|
||||
|
||||
# Check for invalid characters
|
||||
if not TOOL_NAME_REGEX.match(name):
|
||||
# Find all invalid characters (unique, preserving order)
|
||||
invalid_chars: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for char in name:
|
||||
if not re.match(r"[A-Za-z0-9._-]", char) and char not in seen:
|
||||
invalid_chars.append(char)
|
||||
seen.add(char)
|
||||
|
||||
warnings.append(f"Tool name contains invalid characters: {', '.join(repr(c) for c in invalid_chars)}")
|
||||
warnings.append("Allowed characters are: A-Z, a-z, 0-9, underscore (_), dash (-), and dot (.)")
|
||||
|
||||
return ToolNameValidationResult(is_valid=False, warnings=warnings)
|
||||
|
||||
return ToolNameValidationResult(is_valid=True, warnings=warnings)
|
||||
|
||||
|
||||
def issue_tool_name_warning(name: str, warnings: list[str]) -> None:
|
||||
"""Log warnings for non-conforming tool names.
|
||||
|
||||
Args:
|
||||
name: The tool name that triggered the warnings.
|
||||
warnings: List of warning messages to log.
|
||||
"""
|
||||
if not warnings:
|
||||
return
|
||||
|
||||
logger.warning(f'Tool name validation warning for "{name}":')
|
||||
for warning in warnings:
|
||||
logger.warning(f" - {warning}")
|
||||
logger.warning("Tool registration will proceed, but this may cause compatibility issues.")
|
||||
logger.warning("Consider updating the tool name to conform to the MCP tool naming standard.")
|
||||
logger.warning(f"See SEP-986 ({SEP_986_URL}) for more details.")
|
||||
|
||||
|
||||
def validate_and_warn_tool_name(name: str) -> bool:
|
||||
"""Validate a tool name and issue warnings for non-conforming names.
|
||||
|
||||
This is the primary entry point for tool name validation. It validates
|
||||
the name and logs any warnings via the logging module.
|
||||
|
||||
Args:
|
||||
name: The tool name to validate.
|
||||
|
||||
Returns:
|
||||
True if the name is valid, False otherwise.
|
||||
"""
|
||||
result = validate_tool_name(name)
|
||||
issue_tool_name_warning(name, result.warnings)
|
||||
return result.is_valid
|
||||
3
.venv/lib/python3.11/site-packages/mcp/shared/version.py
Normal file
3
.venv/lib/python3.11/site-packages/mcp/shared/version.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from mcp.types import LATEST_PROTOCOL_VERSION
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", "2025-06-18", LATEST_PROTOCOL_VERSION]
|
||||
Reference in New Issue
Block a user