Fix project isolation: Make loadChatHistory respect active project sessions

- Modified loadChatHistory() to check for active project before fetching all sessions
- When active project exists, use project.sessions instead of fetching from API
- Added detailed console logging to debug session filtering
- This prevents ALL sessions from appearing in every project's sidebar

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
uroma
2026-01-22 14:43:05 +00:00
Unverified
parent b82837aa5f
commit 55aafbae9a
6463 changed files with 1115462 additions and 4486 deletions

View File

@@ -0,0 +1,87 @@
"""Utilities for creating standardized httpx AsyncClient instances."""
from typing import Any, Protocol
import httpx
__all__ = ["create_mcp_http_client", "MCP_DEFAULT_TIMEOUT", "MCP_DEFAULT_SSE_READ_TIMEOUT"]
# Default MCP timeout configuration
MCP_DEFAULT_TIMEOUT = 30.0 # General operations (seconds)
MCP_DEFAULT_SSE_READ_TIMEOUT = 300.0 # SSE streams - 5 minutes (seconds)
class McpHttpClientFactory(Protocol): # pragma: no branch
def __call__( # pragma: no branch
self,
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient: ...
def create_mcp_http_client(
    headers: dict[str, str] | None = None,
    timeout: httpx.Timeout | None = None,
    auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
    """Build an ``httpx.AsyncClient`` pre-configured with MCP defaults.

    Defaults applied:
    - ``follow_redirects=True`` (always enabled)
    - timeout of ``MCP_DEFAULT_TIMEOUT`` seconds, with
      ``MCP_DEFAULT_SSE_READ_TIMEOUT`` for reads, when no timeout is supplied

    Args:
        headers: Optional headers to include with all requests.
        timeout: Request timeout as httpx.Timeout object.
            Defaults to 30 seconds if not specified.
        auth: Optional authentication handler.

    Returns:
        Configured httpx.AsyncClient instance with MCP defaults.

    Note:
        The returned AsyncClient must be used as a context manager to ensure
        proper cleanup of connections.

    Examples:
        # Basic usage with MCP defaults
        async with create_mcp_http_client() as client:
            response = await client.get("https://api.example.com")

        # With custom headers and timeout
        headers = {"Authorization": "Bearer token"}
        timeout = httpx.Timeout(60.0, read=300.0)
        async with create_mcp_http_client(headers, timeout) as client:
            response = await client.get("/long-request")

        # With authentication
        from httpx import BasicAuth
        auth = BasicAuth(username="user", password="pass")
        async with create_mcp_http_client(headers, timeout, auth) as client:
            response = await client.get("/protected-endpoint")
    """
    # Fall back to the MCP default timeouts when the caller gave none.
    effective_timeout = (
        timeout
        if timeout is not None
        else httpx.Timeout(MCP_DEFAULT_TIMEOUT, read=MCP_DEFAULT_SSE_READ_TIMEOUT)
    )

    kwargs: dict[str, Any] = {
        "follow_redirects": True,
        "timeout": effective_timeout,
    }
    if headers is not None:
        kwargs["headers"] = headers
    if auth is not None:
        kwargs["auth"] = auth
    return httpx.AsyncClient(**kwargs)

View File

@@ -0,0 +1,159 @@
from typing import Any, Literal
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_validator
class OAuthToken(BaseModel):
    """
    OAuth 2.0 access token response payload.

    See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
    """

    access_token: str
    # RFC 6749 only defines "Bearer" here; the validator below normalizes casing.
    token_type: Literal["Bearer"] = "Bearer"
    expires_in: int | None = None
    scope: str | None = None
    refresh_token: str | None = None

    @field_validator("token_type", mode="before")
    @classmethod
    def normalize_token_type(cls, v: str | None) -> str | None:
        """Accept any casing from servers (e.g. "bearer") and normalize it."""
        if isinstance(v, str):
            # Bearer is title-cased in the spec, so we normalize it
            # https://datatracker.ietf.org/doc/html/rfc6750#section-4
            return v.title()
        return v  # pragma: no cover
class InvalidScopeError(Exception):
    """Raised when a client requests a scope it was not registered with."""

    def __init__(self, message: str):
        # Pass the message to Exception.__init__ so str(exc), exc.args and
        # tracebacks show it (previously they were empty).
        super().__init__(message)
        self.message = message
class InvalidRedirectUriError(Exception):
    """Raised when a redirect URI is missing or not registered for the client."""

    def __init__(self, message: str):
        # Pass the message to Exception.__init__ so str(exc), exc.args and
        # tracebacks show it (previously they were empty).
        super().__init__(message)
        self.message = message
class OAuthClientMetadata(BaseModel):
    """
    RFC 7591 OAuth 2.0 Dynamic Client Registration metadata.
    See https://datatracker.ietf.org/doc/html/rfc7591#section-2
    for the full specification.
    """

    # Field(...) makes this required despite the Optional annotation; when a
    # list is supplied it must contain at least one URI.
    redirect_uris: list[AnyUrl] | None = Field(..., min_length=1)
    # supported auth methods for the token endpoint
    token_endpoint_auth_method: (
        Literal["none", "client_secret_post", "client_secret_basic", "private_key_jwt"] | None
    ) = None
    # supported grant_types of this implementation
    grant_types: list[
        Literal["authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:jwt-bearer"] | str
    ] = [
        "authorization_code",
        "refresh_token",
    ]
    # The MCP spec requires the "code" response type, but OAuth
    # servers may also return additional types they support
    response_types: list[str] = ["code"]
    scope: str | None = None

    # these fields are currently unused, but we support & store them for potential
    # future use
    client_name: str | None = None
    client_uri: AnyHttpUrl | None = None
    logo_uri: AnyHttpUrl | None = None
    contacts: list[str] | None = None
    tos_uri: AnyHttpUrl | None = None
    policy_uri: AnyHttpUrl | None = None
    jwks_uri: AnyHttpUrl | None = None
    jwks: Any | None = None
    software_id: str | None = None
    software_version: str | None = None

    def validate_scope(self, requested_scope: str | None) -> list[str] | None:
        """Validate a space-delimited scope string against the registered scopes.

        Returns the list of requested scopes, or None when no scope was requested.

        Raises:
            InvalidScopeError: If any requested scope was not registered.
        """
        if requested_scope is None:
            return None
        requested_scopes = requested_scope.split(" ")
        allowed_scopes = [] if self.scope is None else self.scope.split(" ")
        for scope in requested_scopes:
            if scope not in allowed_scopes:  # pragma: no branch
                raise InvalidScopeError(f"Client was not registered with scope {scope}")
        return requested_scopes  # pragma: no cover

    def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
        """Resolve and validate the redirect URI for an authorization request.

        If redirect_uri is provided it must match one of the registered URIs.
        If omitted, it defaults to the single registered URI.

        Raises:
            InvalidRedirectUriError: If the URI is unregistered, or omitted while
                zero or multiple URIs are registered.
        """
        if redirect_uri is not None:
            # Validate redirect_uri against client's registered redirect URIs
            if self.redirect_uris is None or redirect_uri not in self.redirect_uris:
                raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client")
            return redirect_uri
        elif self.redirect_uris is not None and len(self.redirect_uris) == 1:
            return self.redirect_uris[0]
        else:
            raise InvalidRedirectUriError("redirect_uri must be specified when client has multiple registered URIs")
class OAuthClientInformationFull(OAuthClientMetadata):
    """
    RFC 7591 OAuth 2.0 Dynamic Client Registration full response
    (client information plus metadata).
    """

    # NOTE(review): all four fields are optional here even though RFC 7591
    # requires client_id in a registration response — presumably to tolerate
    # partial/stored records; confirm against callers.
    client_id: str | None = None
    client_secret: str | None = None
    client_id_issued_at: int | None = None
    client_secret_expires_at: int | None = None
class OAuthMetadata(BaseModel):
    """
    RFC 8414 OAuth 2.0 Authorization Server Metadata.
    See https://datatracker.ietf.org/doc/html/rfc8414#section-2
    """

    # Core server identity and required endpoints.
    issuer: AnyHttpUrl
    authorization_endpoint: AnyHttpUrl
    token_endpoint: AnyHttpUrl
    registration_endpoint: AnyHttpUrl | None = None
    scopes_supported: list[str] | None = None
    response_types_supported: list[str] = ["code"]
    response_modes_supported: list[str] | None = None
    grant_types_supported: list[str] | None = None
    token_endpoint_auth_methods_supported: list[str] | None = None
    token_endpoint_auth_signing_alg_values_supported: list[str] | None = None
    service_documentation: AnyHttpUrl | None = None
    ui_locales_supported: list[str] | None = None
    op_policy_uri: AnyHttpUrl | None = None
    op_tos_uri: AnyHttpUrl | None = None
    # Optional token revocation endpoint metadata.
    revocation_endpoint: AnyHttpUrl | None = None
    revocation_endpoint_auth_methods_supported: list[str] | None = None
    revocation_endpoint_auth_signing_alg_values_supported: list[str] | None = None
    # Optional token introspection endpoint metadata.
    introspection_endpoint: AnyHttpUrl | None = None
    introspection_endpoint_auth_methods_supported: list[str] | None = None
    introspection_endpoint_auth_signing_alg_values_supported: list[str] | None = None
    code_challenge_methods_supported: list[str] | None = None
    client_id_metadata_document_supported: bool | None = None
class ProtectedResourceMetadata(BaseModel):
    """
    RFC 9728 OAuth 2.0 Protected Resource Metadata.
    See https://datatracker.ietf.org/doc/html/rfc9728#section-2
    """

    resource: AnyHttpUrl
    # At least one authorization server must be listed.
    authorization_servers: list[AnyHttpUrl] = Field(..., min_length=1)
    jwks_uri: AnyHttpUrl | None = None
    scopes_supported: list[str] | None = None
    bearer_methods_supported: list[str] | None = Field(default=["header"])  # MCP only supports header method
    resource_signing_alg_values_supported: list[str] | None = None
    resource_name: str | None = None
    resource_documentation: AnyHttpUrl | None = None
    resource_policy_uri: AnyHttpUrl | None = None
    resource_tos_uri: AnyHttpUrl | None = None
    # tls_client_certificate_bound_access_tokens default is False, but omitted here for clarity
    tls_client_certificate_bound_access_tokens: bool | None = None
    authorization_details_types_supported: list[str] | None = None
    dpop_signing_alg_values_supported: list[str] | None = None
    # dpop_bound_access_tokens_required default is False, but omitted here for clarity
    dpop_bound_access_tokens_required: bool | None = None

View File

@@ -0,0 +1,85 @@
"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707) and PKCE (RFC 7636)."""
import time
from urllib.parse import urlparse, urlsplit, urlunsplit
from pydantic import AnyUrl, HttpUrl
def resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str:
    """Return the canonical RFC 8707 resource URL for a server URL.

    RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment
    component". The scheme and host are lowercased to produce a canonical
    absolute URI.

    Args:
        url: Server URL to convert

    Returns:
        Canonical resource URL string
    """
    # Accept str or pydantic URL types uniformly.
    parts = urlsplit(str(url))
    # Lowercase scheme/host and drop the fragment for the canonical form.
    normalized = parts._replace(
        scheme=parts.scheme.lower(),
        netloc=parts.netloc.lower(),
        fragment="",
    )
    return urlunsplit(normalized)
def check_resource_allowed(requested_resource: str, configured_resource: str) -> bool:
    """Decide whether a requested resource URL is covered by a configured one.

    A requested resource matches when both URLs share the same origin
    (scheme, host, port) and the requested path sits at or below the
    configured path. This hierarchical matching lets a token issued for a
    parent resource be used for its child resources.

    Args:
        requested_resource: The resource URL being requested
        configured_resource: The resource URL that has been configured

    Returns:
        True if the requested resource matches the configured resource
    """
    req = urlparse(requested_resource)
    conf = urlparse(configured_resource)

    # Origins must match, compared case-insensitively.
    same_origin = (
        req.scheme.lower() == conf.scheme.lower()
        and req.netloc.lower() == conf.netloc.lower()
    )
    if not same_origin:
        return False

    req_path = req.path
    conf_path = conf.path

    # A shorter requested path can never be a child of the configured path.
    if len(req_path) < len(conf_path):
        return False

    # Normalize both with a trailing slash so "/api123" does not falsely
    # match "/api"; also handles requested=/foo vs configured=/foo/.
    if not req_path.endswith("/"):
        req_path += "/"
    if not conf_path.endswith("/"):
        conf_path += "/"

    return req_path.startswith(conf_path)
def calculate_token_expiry(expires_in: int | str | None) -> float | None:
"""Calculate token expiry timestamp from expires_in seconds.
Args:
expires_in: Seconds until token expiration (may be string from some servers)
Returns:
Unix timestamp when token expires, or None if no expiry specified
"""
if expires_in is None:
return None # pragma: no cover
# Defensive: handle servers that return expires_in as string
return time.time() + int(expires_in)

View File

@@ -0,0 +1,32 @@
"""
Request context for MCP handlers.
"""
from dataclasses import dataclass, field
from typing import Any, Generic
from typing_extensions import TypeVar
from mcp.shared.message import CloseSSEStreamCallback
from mcp.shared.session import BaseSession
from mcp.types import RequestId, RequestParams
# Session type bound to any BaseSession specialization.
SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
LifespanContextT = TypeVar("LifespanContextT")
# Defaults to Any so handlers that don't care about the raw request type can
# omit the third type parameter.
RequestT = TypeVar("RequestT", default=Any)


@dataclass
class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
    """Per-request context handed to MCP request handlers."""

    # ID of the JSON-RPC request being handled.
    request_id: RequestId
    # Optional request metadata from the client, if any.
    meta: RequestParams.Meta | None
    # The session the request arrived on.
    session: SessionT
    # Value produced by the server's lifespan context manager.
    lifespan_context: LifespanContextT
    # NOTE: This is typed as Any to avoid circular imports. The actual type is
    # mcp.server.experimental.request_context.Experimental, but importing it here
    # triggers mcp.server.__init__ -> fastmcp -> tools -> back to this module.
    # The Server sets this to an Experimental instance at runtime.
    experimental: Any = field(default=None)
    # The raw transport-level request object, when available.
    request: RequestT | None = None
    # Callbacks for closing SSE streams; presumably set only by SSE-based
    # transports — confirm against transport wiring.
    close_sse_stream: CloseSSEStreamCallback | None = None
    close_standalone_sse_stream: CloseSSEStreamCallback | None = None

View File

@@ -0,0 +1,71 @@
from __future__ import annotations
from typing import Any, cast
from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData
class McpError(Exception):
    """
    Exception type raised when an error arrives over an MCP connection.
    """

    # The structured JSON-RPC error payload that produced this exception.
    error: ErrorData

    def __init__(self, error: ErrorData):
        """Initialize McpError from a JSON-RPC ErrorData payload."""
        # Use the protocol-level message as the exception message.
        super().__init__(error.message)
        self.error = error
class UrlElicitationRequiredError(McpError):
    """
    Error raised when a tool requires URL mode elicitation(s) before proceeding.

    Servers can raise this from tool handlers to tell the client that one or
    more URL elicitations must be completed before the request can be processed.

    Example:
        raise UrlElicitationRequiredError([
            ElicitRequestURLParams(
                mode="url",
                message="Authorization required for your files",
                url="https://example.com/oauth/authorize",
                elicitationId="auth-001"
            )
        ])
    """

    def __init__(
        self,
        elicitations: list[ElicitRequestURLParams],
        message: str | None = None,
    ):
        """Initialize UrlElicitationRequiredError."""
        if message is None:
            # Pluralize the default message when several elicitations are pending.
            message = f"URL elicitation{'s' if len(elicitations) > 1 else ''} required"
        self._elicitations = elicitations
        payload = {
            "elicitations": [e.model_dump(by_alias=True, exclude_none=True) for e in elicitations],
        }
        super().__init__(
            ErrorData(
                code=URL_ELICITATION_REQUIRED,
                message=message,
                data=payload,
            )
        )

    @property
    def elicitations(self) -> list[ElicitRequestURLParams]:
        """The list of URL elicitations required before the request can proceed."""
        return self._elicitations

    @classmethod
    def from_error(cls, error: ErrorData) -> UrlElicitationRequiredError:
        """Reconstruct from an ErrorData received over the wire.

        Raises:
            ValueError: If the error code is not URL_ELICITATION_REQUIRED.
        """
        if error.code != URL_ELICITATION_REQUIRED:
            raise ValueError(f"Expected error code {URL_ELICITATION_REQUIRED}, got {error.code}")
        payload = cast(dict[str, Any], error.data or {})
        raw_items = cast(list[dict[str, Any]], payload.get("elicitations", []))
        parsed = [ElicitRequestURLParams.model_validate(item) for item in raw_items]
        return cls(parsed, error.message)

View File

@@ -0,0 +1,7 @@
"""
Pure experimental MCP features (no server dependencies).
WARNING: These APIs are experimental and may change without notice.
For server-integrated experimental features, use mcp.server.experimental.
"""

View File

@@ -0,0 +1,12 @@
"""
Pure task state management for MCP.
WARNING: These APIs are experimental and may change without notice.
Import directly from submodules:
- mcp.shared.experimental.tasks.store.TaskStore
- mcp.shared.experimental.tasks.context.TaskContext
- mcp.shared.experimental.tasks.in_memory_task_store.InMemoryTaskStore
- mcp.shared.experimental.tasks.message_queue.TaskMessageQueue
- mcp.shared.experimental.tasks.helpers.is_terminal
"""

View File

@@ -0,0 +1,115 @@
"""
Tasks capability checking utilities.
This module provides functions for checking and requiring task-related
capabilities. All tasks capability logic is centralized here to keep
the main session code clean.
WARNING: These APIs are experimental and may change without notice.
"""
from mcp.shared.exceptions import McpError
from mcp.types import (
INVALID_REQUEST,
ClientCapabilities,
ClientTasksCapability,
ErrorData,
)
def check_tasks_capability(
    required: ClientTasksCapability,
    client: ClientTasksCapability,
) -> bool:
    """
    Check if client's tasks capability matches the required capability.

    Args:
        required: The capability being checked for
        client: The client's declared capabilities

    Returns:
        True if client has the required capability, False otherwise
    """
    req = required.requests
    if req is None:
        # Nothing specific required; any client qualifies.
        return True
    cli = client.requests
    if cli is None:
        return False

    # elicitation.create
    if req.elicitation is not None:
        if cli.elicitation is None:
            return False
        if req.elicitation.create is not None and cli.elicitation.create is None:
            return False

    # sampling.createMessage
    if req.sampling is not None:
        if cli.sampling is None:
            return False
        if req.sampling.createMessage is not None and cli.sampling.createMessage is None:
            return False

    return True
def has_task_augmented_elicitation(caps: ClientCapabilities) -> bool:
    """Check if capabilities include task-augmented elicitation support."""
    # Walk tasks -> requests -> elicitation -> create; any missing link means no.
    return (
        caps.tasks is not None
        and caps.tasks.requests is not None
        and caps.tasks.requests.elicitation is not None
        and caps.tasks.requests.elicitation.create is not None
    )
def has_task_augmented_sampling(caps: ClientCapabilities) -> bool:
    """Check if capabilities include task-augmented sampling support."""
    # Walk tasks -> requests -> sampling -> createMessage; any missing link means no.
    return (
        caps.tasks is not None
        and caps.tasks.requests is not None
        and caps.tasks.requests.sampling is not None
        and caps.tasks.requests.sampling.createMessage is not None
    )
def require_task_augmented_elicitation(client_caps: ClientCapabilities | None) -> None:
    """
    Raise McpError if client doesn't support task-augmented elicitation.

    Args:
        client_caps: The client's declared capabilities, or None if not initialized

    Raises:
        McpError: If client doesn't support task-augmented elicitation
    """
    supported = client_caps is not None and has_task_augmented_elicitation(client_caps)
    if supported:
        return
    raise McpError(
        ErrorData(
            code=INVALID_REQUEST,
            message="Client does not support task-augmented elicitation",
        )
    )
def require_task_augmented_sampling(client_caps: ClientCapabilities | None) -> None:
    """
    Raise McpError if client doesn't support task-augmented sampling.

    Args:
        client_caps: The client's declared capabilities, or None if not initialized

    Raises:
        McpError: If client doesn't support task-augmented sampling
    """
    supported = client_caps is not None and has_task_augmented_sampling(client_caps)
    if supported:
        return
    raise McpError(
        ErrorData(
            code=INVALID_REQUEST,
            message="Client does not support task-augmented sampling",
        )
    )

View File

@@ -0,0 +1,101 @@
"""
TaskContext - Pure task state management.
This module provides TaskContext, which manages task state without any
server/session dependencies. It can be used standalone for distributed
workers or wrapped by ServerTaskContext for full server integration.
"""
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.types import TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, Result, Task
class TaskContext:
    """
    Pure task state management - no session dependencies.

    This class handles:
    - Task state (status, result)
    - Cancellation tracking
    - Store interactions

    For server-integrated features (elicit, create_message, notifications),
    use ServerTaskContext from mcp.server.experimental.

    Example (distributed worker):
        async def worker_job(task_id: str):
            store = RedisTaskStore(redis_url)
            task = await store.get_task(task_id)
            ctx = TaskContext(task=task, store=store)
            await ctx.update_status("Working...")
            result = await do_work()
            await ctx.complete(result)
    """

    def __init__(self, task: Task, store: TaskStore):
        """Initialize with the task's current snapshot and its backing store.

        Args:
            task: The task state this context manages.
            store: The TaskStore used to persist updates.
        """
        self._task = task
        self._store = store
        # In-memory cooperative cancellation flag; not persisted to the store.
        self._cancelled = False

    @property
    def task_id(self) -> str:
        """The task identifier."""
        return self._task.taskId

    @property
    def task(self) -> Task:
        """The current task state (refreshed after each store update)."""
        return self._task

    @property
    def is_cancelled(self) -> bool:
        """Whether cancellation has been requested."""
        return self._cancelled

    def request_cancellation(self) -> None:
        """
        Request cancellation of this task.

        This sets is_cancelled=True. Task work should check this
        periodically and exit gracefully if set.
        """
        self._cancelled = True

    async def update_status(self, message: str) -> None:
        """
        Update the task's status message.

        Args:
            message: The new status message
        """
        self._task = await self._store.update_task(
            self.task_id,
            status_message=message,
        )

    async def complete(self, result: Result) -> None:
        """
        Mark the task as completed with the given result.

        The result is stored before the status flips to completed, so a reader
        never sees a completed task without its result.

        Args:
            result: The task result
        """
        await self._store.store_result(self.task_id, result)
        self._task = await self._store.update_task(
            self.task_id,
            status=TASK_STATUS_COMPLETED,
        )

    async def fail(self, error: str) -> None:
        """
        Mark the task as failed with an error message.

        Args:
            error: The error message
        """
        self._task = await self._store.update_task(
            self.task_id,
            status=TASK_STATUS_FAILED,
            status_message=error,
        )

View File

@@ -0,0 +1,181 @@
"""
Helper functions for pure task management.
These helpers work with pure TaskContext and don't require server dependencies.
For server-integrated task helpers, use mcp.server.experimental.
"""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from uuid import uuid4
from mcp.shared.exceptions import McpError
from mcp.shared.experimental.tasks.context import TaskContext
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.types import (
INVALID_PARAMS,
TASK_STATUS_CANCELLED,
TASK_STATUS_COMPLETED,
TASK_STATUS_FAILED,
TASK_STATUS_WORKING,
CancelTaskResult,
ErrorData,
Task,
TaskMetadata,
TaskStatus,
)
# Metadata key for model-immediate-response (per MCP spec)
# Servers MAY include this in CreateTaskResult._meta to provide an immediate
# response string while the task executes in the background.
MODEL_IMMEDIATE_RESPONSE_KEY = "io.modelcontextprotocol/model-immediate-response"

# Metadata key for associating requests with a task (per MCP spec).
# Used in a message's _meta to route it to the task it belongs to.
RELATED_TASK_METADATA_KEY = "io.modelcontextprotocol/related-task"
def is_terminal(status: TaskStatus) -> bool:
    """
    Check if a task status represents a terminal state.

    Terminal states are those where the task has finished and will not change.

    Args:
        status: The task status to check

    Returns:
        True if the status is terminal (completed, failed, or cancelled)
    """
    terminal_states = {TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}
    return status in terminal_states
async def cancel_task(
    store: TaskStore,
    task_id: str,
) -> CancelTaskResult:
    """
    Cancel a task with spec-compliant validation.

    Per spec: "Receivers MUST reject cancellation of terminal status tasks
    with -32602 (Invalid params)"

    Validates that the task exists and is not already terminal, then sets it
    to "cancelled".

    Args:
        store: The task store
        task_id: The task identifier to cancel

    Returns:
        CancelTaskResult with the cancelled task state

    Raises:
        McpError: With INVALID_PARAMS (-32602) if the task does not exist or
            is already in a terminal state (completed, failed, cancelled).

    Example:
        @server.experimental.cancel_task()
        async def handle_cancel(request: CancelTaskRequest) -> CancelTaskResult:
            return await cancel_task(store, request.params.taskId)
    """
    task = await store.get_task(task_id)
    if task is None:
        raise McpError(
            ErrorData(code=INVALID_PARAMS, message=f"Task not found: {task_id}")
        )

    if is_terminal(task.status):
        raise McpError(
            ErrorData(
                code=INVALID_PARAMS,
                message=f"Cannot cancel task in terminal state '{task.status}'",
            )
        )

    # Transition the task to cancelled and surface the updated state.
    updated = await store.update_task(task_id, status=TASK_STATUS_CANCELLED)
    return CancelTaskResult(**updated.model_dump())
def generate_task_id() -> str:
    """Return a unique identifier for a new task (random UUID string)."""
    new_id = uuid4()
    return str(new_id)
def create_task_state(
    metadata: TaskMetadata,
    task_id: str | None = None,
) -> Task:
    """
    Create a Task object with initial state.

    This is a helper for TaskStore implementations.

    Args:
        metadata: Task metadata
        task_id: Optional task ID (generated if not provided)

    Returns:
        A new Task in "working" status
    """
    created = datetime.now(timezone.utc)
    return Task(
        taskId=task_id or generate_task_id(),
        status=TASK_STATUS_WORKING,
        createdAt=created,
        lastUpdatedAt=created,  # identical to createdAt at birth
        ttl=metadata.ttl,
        pollInterval=500,  # Default 500ms poll interval
    )
@asynccontextmanager
async def task_execution(
    task_id: str,
    store: TaskStore,
) -> AsyncIterator[TaskContext]:
    """
    Context manager for safe task execution (pure, no server dependencies).

    Loads a task from the store and provides a TaskContext for the work.
    If an unhandled exception occurs, the task is automatically marked as failed
    and the exception is suppressed (since the failure is captured in task state).

    This is useful for distributed workers that don't have a server session.

    Args:
        task_id: The task identifier to execute
        store: The task store (must be accessible by the worker)

    Yields:
        TaskContext for updating status and completing/failing the task

    Raises:
        ValueError: If the task is not found in the store

    Example (distributed worker):
        async def worker_process(task_id: str):
            store = RedisTaskStore(redis_url)
            async with task_execution(task_id, store) as ctx:
                await ctx.update_status("Working...")
                result = await do_work()
                await ctx.complete(result)
    """
    task = await store.get_task(task_id)
    if task is None:
        raise ValueError(f"Task {task_id} not found")

    ctx = TaskContext(task, store)
    try:
        yield ctx
    # Only Exception is caught, so cancellation (BaseException) still propagates.
    except Exception as e:
        # Auto-fail the task if an exception occurs and task isn't already terminal
        # Exception is suppressed since failure is captured in task state
        if not is_terminal(ctx.task.status):
            await ctx.fail(str(e))
        # Don't re-raise - the failure is recorded in task state

View File

@@ -0,0 +1,219 @@
"""
In-memory implementation of TaskStore for demonstration purposes.
This implementation stores all tasks in memory and provides automatic cleanup
based on the TTL duration specified in the task metadata using lazy expiration.
Note: This is not suitable for production use as all data is lost on restart.
For production, consider implementing TaskStore with a database or distributed cache.
"""
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
import anyio
from mcp.shared.experimental.tasks.helpers import create_task_state, is_terminal
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.types import Result, Task, TaskMetadata, TaskStatus
@dataclass
class StoredTask:
    """Internal storage representation of a task."""

    # The task's public state; copied on read by the store.
    task: Task
    # Result payload, populated via store_result().
    result: Result | None = None
    # Time when this task should be removed (None = never)
    expires_at: datetime | None = field(default=None)
class InMemoryTaskStore(TaskStore):
    """
    A simple in-memory implementation of TaskStore.

    Features:
    - Automatic TTL-based cleanup (lazy expiration)
    - Thread-safe for single-process async use
    - Pagination support for list_tasks
    - Multiple concurrent waiters per task in wait_for_update

    Limitations:
    - All data lost on restart
    - Not suitable for distributed systems
    - No persistence

    For production, implement TaskStore with Redis, PostgreSQL, etc.
    """

    def __init__(self, page_size: int = 10) -> None:
        self._tasks: dict[str, StoredTask] = {}
        self._page_size = page_size
        # Each task may have several concurrent waiters. Every waiter registers
        # its own one-shot event (anyio.Event cannot be reset), so a single
        # notify wakes all of them. A single shared slot per task would drop
        # wakeups for earlier waiters.
        self._update_events: dict[str, list[anyio.Event]] = {}

    def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
        """Calculate expiry time from TTL in milliseconds (None = never)."""
        if ttl_ms is None:
            return None
        return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms)

    def _is_expired(self, stored: StoredTask) -> bool:
        """Check if a task has expired."""
        if stored.expires_at is None:
            return False
        return datetime.now(timezone.utc) >= stored.expires_at

    def _cleanup_expired(self) -> None:
        """Remove all expired tasks. Called lazily during access operations."""
        expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)]
        for task_id in expired_ids:
            del self._tasks[task_id]

    async def create_task(
        self,
        metadata: TaskMetadata,
        task_id: str | None = None,
    ) -> Task:
        """Create a new task with the given metadata.

        Raises:
            ValueError: If a task with the same ID already exists.
        """
        # Cleanup expired tasks on access
        self._cleanup_expired()
        task = create_task_state(metadata, task_id)
        if task.taskId in self._tasks:
            raise ValueError(f"Task with ID {task.taskId} already exists")
        stored = StoredTask(
            task=task,
            expires_at=self._calculate_expiry(metadata.ttl),
        )
        self._tasks[task.taskId] = stored
        # Return a copy to prevent external modification
        return Task(**task.model_dump())

    async def get_task(self, task_id: str) -> Task | None:
        """Get a task by ID; returns None for unknown or expired tasks."""
        # Cleanup expired tasks on access
        self._cleanup_expired()
        stored = self._tasks.get(task_id)
        if stored is None:
            return None
        # Return a copy to prevent external modification
        return Task(**stored.task.model_dump())

    async def update_task(
        self,
        task_id: str,
        status: TaskStatus | None = None,
        status_message: str | None = None,
    ) -> Task:
        """Update a task's status and/or message.

        Raises:
            ValueError: If the task does not exist, or a status change away
                from a terminal state is requested.
        """
        stored = self._tasks.get(task_id)
        if stored is None:
            raise ValueError(f"Task with ID {task_id} not found")
        # Per spec: Terminal states MUST NOT transition to any other status
        if status is not None and status != stored.task.status and is_terminal(stored.task.status):
            raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'")
        status_changed = False
        if status is not None and stored.task.status != status:
            stored.task.status = status
            status_changed = True
        if status_message is not None:
            stored.task.statusMessage = status_message
        # Update lastUpdatedAt on any change
        stored.task.lastUpdatedAt = datetime.now(timezone.utc)
        # If task is now terminal and has TTL, reset expiry timer
        if status is not None and is_terminal(status) and stored.task.ttl is not None:
            stored.expires_at = self._calculate_expiry(stored.task.ttl)
        # Notify waiters if status changed
        if status_changed:
            await self.notify_update(task_id)
        return Task(**stored.task.model_dump())

    async def store_result(self, task_id: str, result: Result) -> None:
        """Store the result for a task.

        Raises:
            ValueError: If the task does not exist.
        """
        stored = self._tasks.get(task_id)
        if stored is None:
            raise ValueError(f"Task with ID {task_id} not found")
        stored.result = result

    async def get_result(self, task_id: str) -> Result | None:
        """Get the stored result for a task, or None if absent."""
        stored = self._tasks.get(task_id)
        if stored is None:
            return None
        return stored.result

    async def list_tasks(
        self,
        cursor: str | None = None,
    ) -> tuple[list[Task], str | None]:
        """List tasks with pagination.

        The cursor is the taskId of the last item of the previous page.

        Raises:
            ValueError: If the cursor does not reference a known task.
        """
        # Cleanup expired tasks on access
        self._cleanup_expired()
        all_task_ids = list(self._tasks.keys())
        start_index = 0
        if cursor is not None:
            try:
                cursor_index = all_task_ids.index(cursor)
                start_index = cursor_index + 1
            except ValueError:
                # Suppress the unhelpful list.index traceback context.
                raise ValueError(f"Invalid cursor: {cursor}") from None
        page_task_ids = all_task_ids[start_index : start_index + self._page_size]
        tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids]
        # Determine next cursor
        next_cursor = None
        if start_index + self._page_size < len(all_task_ids) and page_task_ids:
            next_cursor = page_task_ids[-1]
        return tasks, next_cursor

    async def delete_task(self, task_id: str) -> bool:
        """Delete a task. Returns False if it did not exist."""
        if task_id not in self._tasks:
            return False
        del self._tasks[task_id]
        return True

    async def wait_for_update(self, task_id: str) -> None:
        """Wait until the task status changes.

        Each call registers its own one-shot event, so multiple concurrent
        waiters on the same task are all woken by the next update. (The
        previous implementation kept a single event per task; a new waiter
        replaced it and earlier waiters could hang forever.)

        Raises:
            ValueError: If the task does not exist.
        """
        if task_id not in self._tasks:
            raise ValueError(f"Task with ID {task_id} not found")
        event = anyio.Event()
        self._update_events.setdefault(task_id, []).append(event)
        try:
            await event.wait()
        finally:
            # Deregister on exit (including cancellation) to avoid leaks.
            waiters = self._update_events.get(task_id)
            if waiters is not None:
                if event in waiters:
                    waiters.remove(event)
                if not waiters:
                    self._update_events.pop(task_id, None)

    async def notify_update(self, task_id: str) -> None:
        """Signal that a task has been updated, waking all current waiters."""
        for event in self._update_events.pop(task_id, []):
            event.set()

    # --- Testing/debugging helpers ---
    def cleanup(self) -> None:
        """Cleanup all tasks (useful for testing or graceful shutdown)."""
        self._tasks.clear()
        self._update_events.clear()

    def get_all_tasks(self) -> list[Task]:
        """Get all tasks (useful for debugging). Returns copies to prevent modification."""
        self._cleanup_expired()
        return [Task(**stored.task.model_dump()) for stored in self._tasks.values()]

View File

@@ -0,0 +1,241 @@
"""
TaskMessageQueue - FIFO queue for task-related messages.
This implements the core message queue pattern from the MCP Tasks spec.
When a handler needs to send a request (like elicitation) during a task-augmented
request, the message is enqueued instead of sent directly. Messages are delivered
to the client only through the `tasks/result` endpoint.
This pattern enables:
1. Decoupling request handling from message delivery
2. Proper bidirectional communication via the tasks/result stream
3. Automatic status management (working <-> input_required)
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
import anyio
from mcp.shared.experimental.tasks.resolver import Resolver
from mcp.types import JSONRPCNotification, JSONRPCRequest, RequestId
@dataclass
class QueuedMessage:
    """
    A message queued for delivery via tasks/result.

    Messages are stored with their type and a resolver for requests
    that expect responses; notifications carry no resolver.
    """

    type: Literal["request", "notification"]
    """Whether this is a request (expects response) or notification (one-way)."""

    message: JSONRPCRequest | JSONRPCNotification
    """The JSON-RPC message to send."""

    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    """When the message was enqueued (timezone-aware UTC)."""

    resolver: Resolver[dict[str, Any]] | None = None
    """Resolver to set when response arrives (only for requests)."""

    original_request_id: RequestId | None = None
    """The original request ID used internally, for routing responses back."""
class TaskMessageQueue(ABC):
    """
    Abstract interface for task message queuing.

    This is a FIFO queue that stores messages to be delivered via `tasks/result`.
    When a task-augmented handler calls elicit() or sends a notification, the
    message is enqueued here instead of being sent directly to the client.

    The `tasks/result` handler then dequeues and sends these messages through
    the transport, with `relatedRequestId` set to the tasks/result request ID
    so responses are routed correctly.

    Implementations can use in-memory storage, Redis, etc. FIFO ordering must
    be preserved per task.
    """

    @abstractmethod
    async def enqueue(self, task_id: str, message: QueuedMessage) -> None:
        """
        Add a message to the queue for a task.

        Implementations should also wake any coroutines blocked in
        wait_for_message() for this task.

        Args:
            task_id: The task identifier
            message: The message to enqueue
        """

    @abstractmethod
    async def dequeue(self, task_id: str) -> QueuedMessage | None:
        """
        Remove and return the next (oldest) message from the queue.

        Args:
            task_id: The task identifier

        Returns:
            The next message, or None if queue is empty
        """

    @abstractmethod
    async def peek(self, task_id: str) -> QueuedMessage | None:
        """
        Return the next message without removing it.

        Args:
            task_id: The task identifier

        Returns:
            The next message, or None if queue is empty
        """

    @abstractmethod
    async def is_empty(self, task_id: str) -> bool:
        """
        Check if the queue is empty for a task.

        Args:
            task_id: The task identifier

        Returns:
            True if no messages are queued
        """

    @abstractmethod
    async def clear(self, task_id: str) -> list[QueuedMessage]:
        """
        Remove and return all messages from the queue.

        This is useful for cleanup when a task is cancelled or completed.

        Args:
            task_id: The task identifier

        Returns:
            All queued messages (may be empty)
        """

    @abstractmethod
    async def wait_for_message(self, task_id: str) -> None:
        """
        Wait until a message is available in the queue.

        This blocks until either:
        1. A message is enqueued for this task
        2. The wait is cancelled

        Args:
            task_id: The task identifier
        """

    @abstractmethod
    async def notify_message_available(self, task_id: str) -> None:
        """
        Signal that a message is available for a task.

        This wakes up any coroutines waiting in wait_for_message().

        Args:
            task_id: The task identifier
        """
class InMemoryTaskMessageQueue(TaskMessageQueue):
    """
    In-memory implementation of TaskMessageQueue.

    This is suitable for single-process servers. For distributed systems,
    implement TaskMessageQueue with Redis, RabbitMQ, etc.

    Features:
    - FIFO ordering per task
    - Async wait for message availability
    - Thread-safe for single-process async use
    """

    def __init__(self) -> None:
        # task_id -> FIFO list of pending messages
        self._queues: dict[str, list[QueuedMessage]] = {}
        # task_id -> event used to wake waiters in wait_for_message()
        self._events: dict[str, anyio.Event] = {}

    def _get_queue(self, task_id: str) -> list[QueuedMessage]:
        """Get or create the queue for a task."""
        return self._queues.setdefault(task_id, [])

    async def enqueue(self, task_id: str, message: QueuedMessage) -> None:
        """Append a message to the task's FIFO queue and wake waiters."""
        self._get_queue(task_id).append(message)
        # Signal that a message is available
        await self.notify_message_available(task_id)

    async def dequeue(self, task_id: str) -> QueuedMessage | None:
        """Remove and return the oldest message, or None if the queue is empty."""
        queue = self._get_queue(task_id)
        return queue.pop(0) if queue else None

    async def peek(self, task_id: str) -> QueuedMessage | None:
        """Return the oldest message without removing it, or None if empty."""
        queue = self._get_queue(task_id)
        return queue[0] if queue else None

    async def is_empty(self, task_id: str) -> bool:
        """Check whether no messages are queued for the task."""
        return len(self._get_queue(task_id)) == 0

    async def clear(self, task_id: str) -> list[QueuedMessage]:
        """Remove and return all queued messages (cleanup on cancel/complete)."""
        queue = self._get_queue(task_id)
        messages = list(queue)
        queue.clear()
        return messages

    async def wait_for_message(self, task_id: str) -> None:
        """Wait until a message is available for this task."""
        # Check if there are already messages
        if not await self.is_empty(task_id):
            return
        # Reuse an existing un-set event so several concurrent waiters share
        # one event and ALL wake on the next enqueue. (The previous version
        # replaced the event unconditionally, which stranded earlier waiters:
        # notify_message_available() only sets the most recently stored event.)
        event = self._events.get(task_id)
        if event is None or event.is_set():
            # anyio.Event cannot be cleared, so a fresh one is needed once set.
            event = anyio.Event()
            self._events[task_id] = event
        # Double-check after publishing the event (closes the race where a
        # message was enqueued between the first check and event creation).
        if not await self.is_empty(task_id):
            return
        # Wait for a new message
        await event.wait()

    async def notify_message_available(self, task_id: str) -> None:
        """Wake any coroutines waiting in wait_for_message() for this task."""
        event = self._events.get(task_id)
        if event is not None:
            event.set()

    def cleanup(self, task_id: str | None = None) -> None:
        """
        Clean up queues and events.

        Args:
            task_id: If provided, clean up only this task. Otherwise clean up all.
        """
        if task_id is not None:
            self._queues.pop(task_id, None)
            self._events.pop(task_id, None)
        else:
            self._queues.clear()
            self._events.clear()

View File

@@ -0,0 +1,45 @@
"""
Shared polling utilities for task operations.
This module provides generic polling logic that works for both client→server
and server→client task polling.
WARNING: These APIs are experimental and may change without notice.
"""
from collections.abc import AsyncIterator, Awaitable, Callable
import anyio
from mcp.shared.experimental.tasks.helpers import is_terminal
from mcp.types import GetTaskResult
async def poll_until_terminal(
    get_task: Callable[[str], Awaitable[GetTaskResult]],
    task_id: str,
    default_interval_ms: int = 500,
) -> AsyncIterator[GetTaskResult]:
    """
    Poll a task until it reaches terminal status.

    Generic utility usable for both client→server and server→client polling:
    the caller supplies the direction-appropriate `get_task` function.

    Args:
        get_task: Async function that takes task_id and returns GetTaskResult
        task_id: The task to poll
        default_interval_ms: Fallback poll interval if server doesn't specify

    Yields:
        GetTaskResult for each poll (including the terminal one)
    """
    reached_terminal = False
    while not reached_terminal:
        snapshot = await get_task(task_id)
        yield snapshot
        reached_terminal = is_terminal(snapshot.status)
        if not reached_terminal:
            # Honor the server-suggested interval when present.
            wait_ms = snapshot.pollInterval if snapshot.pollInterval is not None else default_interval_ms
            await anyio.sleep(wait_ms / 1000)

View File

@@ -0,0 +1,60 @@
"""
Resolver - An anyio-compatible future-like object for async result passing.
This provides a simple way to pass a result (or exception) from one coroutine
to another without depending on asyncio.Future.
"""
from typing import Generic, TypeVar, cast
import anyio
T = TypeVar("T")
class Resolver(Generic[T]):
    """
    A simple one-shot resolver for passing a result between coroutines.

    Unlike asyncio.Future, this works with any anyio-compatible async backend.

    Usage:
        resolver: Resolver[str] = Resolver()

        # In one coroutine:
        resolver.set_result("hello")

        # In another coroutine:
        result = await resolver.wait()  # returns "hello"
    """

    def __init__(self) -> None:
        self._event = anyio.Event()
        self._value: T | None = None
        self._exception: BaseException | None = None

    def _settle(self, value: T | None, exc: BaseException | None) -> None:
        """Record the outcome exactly once and wake all waiters."""
        if self._event.is_set():
            raise RuntimeError("Resolver already completed")
        self._value = value
        self._exception = exc
        self._event.set()

    def set_result(self, value: T) -> None:
        """Set the result value and wake up waiters."""
        self._settle(value, None)

    def set_exception(self, exc: BaseException) -> None:
        """Set an exception and wake up waiters."""
        self._settle(None, exc)

    async def wait(self) -> T:
        """Wait for the result and return it, or raise the stored exception."""
        await self._event.wait()
        if self._exception is not None:
            raise self._exception
        # Reaching here means set_result() ran, so _value holds the payload.
        return cast(T, self._value)

    def done(self) -> bool:
        """Return True once set_result() or set_exception() has been called."""
        return self._event.is_set()

View File

@@ -0,0 +1,156 @@
"""
TaskStore - Abstract interface for task state storage.
"""
from abc import ABC, abstractmethod
from mcp.types import Result, Task, TaskMetadata, TaskStatus
class TaskStore(ABC):
    """
    Abstract interface for task state storage.

    This is a pure storage interface - it doesn't manage execution.
    Implementations can use in-memory storage, databases, Redis, etc.

    All methods are async to support various backends.
    """

    @abstractmethod
    async def create_task(
        self,
        metadata: TaskMetadata,
        task_id: str | None = None,
    ) -> Task:
        """
        Create a new task.

        Args:
            metadata: Task metadata (ttl, etc.)
            task_id: Optional task ID. If None, implementation should generate one.

        Returns:
            The created Task with status="working"

        Raises:
            ValueError: If task_id already exists
        """

    @abstractmethod
    async def get_task(self, task_id: str) -> Task | None:
        """
        Get a task by ID.

        Args:
            task_id: The task identifier

        Returns:
            The Task, or None if not found
        """

    @abstractmethod
    async def update_task(
        self,
        task_id: str,
        status: TaskStatus | None = None,
        status_message: str | None = None,
    ) -> Task:
        """
        Update a task's status and/or message.

        Args:
            task_id: The task identifier
            status: New status (if changing)
            status_message: New status message (if changing)

        Returns:
            The updated Task

        Raises:
            ValueError: If task not found
            ValueError: If attempting to transition from a terminal status
                (completed, failed, cancelled). Per spec, terminal states
                MUST NOT transition to any other status.
        """

    @abstractmethod
    async def store_result(self, task_id: str, result: Result) -> None:
        """
        Store the result for a task.

        Args:
            task_id: The task identifier
            result: The result to store

        Raises:
            ValueError: If task not found
        """

    @abstractmethod
    async def get_result(self, task_id: str) -> Result | None:
        """
        Get the stored result for a task.

        Args:
            task_id: The task identifier

        Returns:
            The stored Result, or None if not available
        """

    @abstractmethod
    async def list_tasks(
        self,
        cursor: str | None = None,
    ) -> tuple[list[Task], str | None]:
        """
        List tasks with pagination.

        Args:
            cursor: Optional cursor for pagination (opaque; produced by a
                previous call)

        Returns:
            Tuple of (tasks, next_cursor). next_cursor is None if no more pages.
        """

    @abstractmethod
    async def delete_task(self, task_id: str) -> bool:
        """
        Delete a task.

        Args:
            task_id: The task identifier

        Returns:
            True if deleted, False if not found
        """

    @abstractmethod
    async def wait_for_update(self, task_id: str) -> None:
        """
        Wait until the task status changes.

        This blocks until either:
        1. The task status changes
        2. The wait is cancelled

        Used by tasks/result to wait for task completion or status changes.

        Args:
            task_id: The task identifier

        Raises:
            ValueError: If task not found
        """

    @abstractmethod
    async def notify_update(self, task_id: str) -> None:
        """
        Signal that a task has been updated.

        This wakes up any coroutines waiting in wait_for_update().

        Args:
            task_id: The task identifier
        """

View File

@@ -0,0 +1,98 @@
"""
In-memory transports
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import timedelta
from typing import Any
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import mcp.types as types
from mcp.client.session import ClientSession, ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
from mcp.server import Server
from mcp.server.fastmcp import FastMCP
from mcp.shared.message import SessionMessage
MessageStream = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]]
@asynccontextmanager
async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageStream, MessageStream], None]:
    """
    Creates a pair of bidirectional memory streams for client-server communication.

    Returns:
        A tuple of (client_streams, server_streams) where each is a tuple of
        (read_stream, write_stream)
    """
    # One memory stream per direction, each buffering a single message.
    s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1)
    c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1)

    # The client reads what the server writes, and vice versa.
    client_streams = (s2c_recv, c2s_send)
    server_streams = (c2s_recv, s2c_send)

    # Ensure every stream end is closed when the context exits.
    async with s2c_recv, c2s_send, c2s_recv, s2c_send:
        yield client_streams, server_streams
@asynccontextmanager
async def create_connected_server_and_client_session(
    server: Server[Any] | FastMCP,
    read_timeout_seconds: timedelta | None = None,
    sampling_callback: SamplingFnT | None = None,
    list_roots_callback: ListRootsFnT | None = None,
    logging_callback: LoggingFnT | None = None,
    message_handler: MessageHandlerFnT | None = None,
    client_info: types.Implementation | None = None,
    raise_exceptions: bool = False,
    elicitation_callback: ElicitationFnT | None = None,
) -> AsyncGenerator[ClientSession, None]:
    """Creates a ClientSession that is connected to a running MCP server.

    The server runs in a background task over in-memory streams; the yielded
    session has already completed initialization. On exit the server task is
    cancelled via the enclosing task group's cancel scope.

    Args:
        server: The server to connect; FastMCP instances are unwrapped to
            their underlying low-level server.
        read_timeout_seconds: Client-side response timeout; None disables it.
        sampling_callback: Optional client handler for sampling requests.
        list_roots_callback: Optional client handler for roots listing.
        logging_callback: Optional client handler for log notifications.
        message_handler: Optional raw message handler on the client.
        client_info: Client implementation info sent during initialization.
        raise_exceptions: Forwarded to Server.run; when True handler errors
            propagate instead of becoming error responses.
        elicitation_callback: Optional client handler for elicitation requests.
    """
    # TODO(Marcelo): we should have a proper `Client` that can use this "in-memory transport",
    # and we should expose a method in the `FastMCP` so we don't access a private attribute.
    if isinstance(server, FastMCP):  # pragma: no cover
        server = server._mcp_server  # type: ignore[reportPrivateUsage]
    async with create_client_server_memory_streams() as (client_streams, server_streams):
        client_read, client_write = client_streams
        server_read, server_write = server_streams

        # Create a cancel scope for the server task
        async with anyio.create_task_group() as tg:
            tg.start_soon(
                lambda: server.run(
                    server_read,
                    server_write,
                    server.create_initialization_options(),
                    raise_exceptions=raise_exceptions,
                )
            )
            try:
                async with ClientSession(
                    read_stream=client_read,
                    write_stream=client_write,
                    read_timeout_seconds=read_timeout_seconds,
                    sampling_callback=sampling_callback,
                    list_roots_callback=list_roots_callback,
                    logging_callback=logging_callback,
                    message_handler=message_handler,
                    client_info=client_info,
                    elicitation_callback=elicitation_callback,
                ) as client_session:
                    await client_session.initialize()
                    yield client_session
            finally:  # pragma: no cover
                # Stop the background server task when the caller is done.
                tg.cancel_scope.cancel()

View File

@@ -0,0 +1,50 @@
"""
Message wrapper with metadata support.
This module defines a wrapper type that combines JSONRPCMessage with metadata
to support transport-specific features like resumability.
"""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from mcp.types import JSONRPCMessage, RequestId
ResumptionToken = str
ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]]
# Callback type for closing SSE streams without terminating
CloseSSEStreamCallback = Callable[[], Awaitable[None]]
@dataclass
class ClientMessageMetadata:
    """Metadata specific to client messages."""

    # Token allowing the transport to resume an interrupted stream.
    resumption_token: ResumptionToken | None = None
    # Invoked whenever the transport issues a new resumption token
    # (annotation uses the module's declared callback alias).
    on_resumption_token_update: ResumptionTokenUpdateCallback | None = None
@dataclass
class ServerMessageMetadata:
    """Metadata specific to server messages."""

    # ID of the client request this server message relates to, used by
    # transports to route the message onto the right stream.
    related_request_id: RequestId | None = None
    # Request-specific context (e.g., headers, auth info)
    request_context: object | None = None
    # Callback to close SSE stream for the current request without terminating
    close_sse_stream: CloseSSEStreamCallback | None = None
    # Callback to close the standalone GET SSE stream (for unsolicited notifications)
    close_standalone_sse_stream: CloseSSEStreamCallback | None = None
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
@dataclass
class SessionMessage:
    """A message with specific metadata for transport-specific features."""

    # The wrapped JSON-RPC message.
    message: JSONRPCMessage
    # Transport-specific metadata (resumability, routing); None when unused.
    metadata: MessageMetadata = None

View File

@@ -0,0 +1,45 @@
"""Utility functions for working with metadata in MCP types.
These utilities are primarily intended for client-side usage to properly display
human-readable names in user interfaces in a spec compliant way.
"""
from mcp.types import Implementation, Prompt, Resource, ResourceTemplate, Tool
def get_display_name(obj: Tool | Resource | Prompt | ResourceTemplate | Implementation) -> str:
    """
    Get the display name for an MCP object with proper precedence.

    Client-side helper for rendering human-readable names in user interfaces:
    a server-provided 'title' field is preferred over the programmatic 'name'.

    For tools: title > annotations.title > name
    For other objects: title > name

    Example:
        # In a client displaying available tools
        tools = await session.list_tools()
        for tool in tools.tools:
            display_name = get_display_name(tool)
            print(f"Available tool: {display_name}")

    Args:
        obj: An MCP object with name and optional title fields

    Returns:
        The display name to use for UI presentation
    """
    title = getattr(obj, "title", None)
    if title is not None:
        return title
    if isinstance(obj, Tool):
        # Tools may also carry a display title on their annotations.
        annotations = obj.annotations
        if annotations and getattr(annotations, "title", None) is not None:
            return annotations.title
    return obj.name

View File

@@ -0,0 +1,58 @@
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Generic
from pydantic import BaseModel
from mcp.shared.context import LifespanContextT, RequestContext
from mcp.shared.session import (
BaseSession,
ReceiveNotificationT,
ReceiveRequestT,
SendNotificationT,
SendRequestT,
SendResultT,
)
from mcp.types import ProgressToken
class Progress(BaseModel):
    """Progress snapshot: amount completed so far and optional known total."""

    progress: float
    total: float | None
@dataclass
class ProgressContext(Generic[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]):
    """Accumulates progress amounts and forwards them to the session as notifications."""

    # The session used to emit progress notifications.
    session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]
    # Token identifying which request this progress belongs to.
    progress_token: ProgressToken
    # Total amount of work, or None when unknown.
    total: float | None
    # Running sum of reported progress; starts at zero, not caller-settable.
    current: float = field(default=0.0, init=False)

    async def progress(self, amount: float, message: str | None = None) -> None:
        """Advance the running total by *amount* and emit a progress notification."""
        self.current += amount
        await self.session.send_progress_notification(
            self.progress_token, self.current, total=self.total, message=message
        )
@contextmanager
def progress(
    ctx: RequestContext[
        BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
        LifespanContextT,
    ],
    total: float | None = None,
) -> Generator[
    ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT],
    None,
    None,
]:
    """
    Create a ProgressContext bound to the request's progress token.

    Args:
        ctx: The request context; its metadata must carry a progress token.
        total: Optional total amount of work, forwarded in notifications.

    Yields:
        A ProgressContext whose .progress() emits progress notifications.

    Raises:
        ValueError: If the request carries no progress token.
    """
    if ctx.meta is None or ctx.meta.progressToken is None:  # pragma: no cover
        raise ValueError("No progress token provided")

    # The original wrapped this yield in a no-op try/finally; there is no
    # cleanup to perform, so the context manager simply yields.
    yield ProgressContext(ctx.session, ctx.meta.progressToken, total)

View File

@@ -0,0 +1,63 @@
"""
ResponseRouter - Protocol for pluggable response routing.
This module defines a protocol for routing JSON-RPC responses to alternative
handlers before falling back to the default response stream mechanism.
The primary use case is task-augmented requests: when a TaskSession enqueues
a request (like elicitation), the response needs to be routed back to the
waiting resolver instead of the normal response stream.
Design:
- Protocol-based for testability and flexibility
- Returns bool to indicate if response was handled
- Supports both success responses and errors
"""
from typing import Any, Protocol
from mcp.types import ErrorData, RequestId
class ResponseRouter(Protocol):
    """
    Protocol for routing responses to alternative handlers.

    Implementations check if they have a pending request for the given ID
    and deliver the response/error to the appropriate handler. Routers are
    consulted before the session's default response-stream mechanism.

    Example:
        class TaskResultHandler(ResponseRouter):
            def route_response(self, request_id, response):
                resolver = self._pending_requests.pop(request_id, None)
                if resolver:
                    resolver.set_result(response)
                    return True
                return False
    """

    def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool:
        """
        Try to route a response to a pending request handler.

        Args:
            request_id: The JSON-RPC request ID from the response
            response: The response result data

        Returns:
            True if the response was handled, False otherwise
        """
        ...  # pragma: no cover

    def route_error(self, request_id: RequestId, error: ErrorData) -> bool:
        """
        Try to route an error to a pending request handler.

        Args:
            request_id: The JSON-RPC request ID from the error response
            error: The error data

        Returns:
            True if the error was handled, False otherwise
        """
        ...  # pragma: no cover

View File

@@ -0,0 +1,550 @@
import logging
from collections.abc import Callable
from contextlib import AsyncExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, Protocol, TypeVar
import anyio
import httpx
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import BaseModel
from typing_extensions import Self
from mcp.shared.exceptions import McpError
from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage
from mcp.shared.response_router import ResponseRouter
from mcp.types import (
CONNECTION_CLOSED,
INVALID_PARAMS,
CancelledNotification,
ClientNotification,
ClientRequest,
ClientResult,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ProgressNotification,
RequestParams,
ServerNotification,
ServerRequest,
ServerResult,
)
SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
RequestId = str | int
class ProgressFnT(Protocol):
    """Protocol for progress notification callbacks."""

    # Called with the cumulative progress value, the optional total, and an
    # optional human-readable status message.
    async def __call__(
        self, progress: float, total: float | None, message: str | None
    ) -> None: ...  # pragma: no branch
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
    """Handles responding to MCP requests and manages request lifecycle.

    This class MUST be used as a context manager to ensure proper cleanup and
    cancellation handling:

    Example:
        with request_responder as resp:
            await resp.respond(result)

    The context manager ensures:
    1. Proper cancellation scope setup and cleanup
    2. Request completion tracking
    3. Cleanup of in-flight requests
    """

    def __init__(
        self,
        request_id: RequestId,
        request_meta: RequestParams.Meta | None,
        request: ReceiveRequestT,
        # String annotation: forward reference, since BaseSession is defined
        # later in this module.
        session: """BaseSession[
            SendRequestT,
            SendNotificationT,
            SendResultT,
            ReceiveRequestT,
            ReceiveNotificationT
        ]""",
        on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
        message_metadata: MessageMetadata = None,
    ) -> None:
        self.request_id = request_id
        self.request_meta = request_meta
        self.request = request
        self.message_metadata = message_metadata
        self._session = session
        self._completed = False
        self._cancel_scope = anyio.CancelScope()
        self._on_complete = on_complete
        self._entered = False  # Track if we're in a context manager

    def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
        """Enter the context manager, enabling request cancellation tracking."""
        self._entered = True
        # A fresh scope per entry so cancel() only affects this request.
        self._cancel_scope = anyio.CancelScope()
        self._cancel_scope.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the context manager, performing cleanup and notifying completion."""
        try:
            # Only completed requests are reported; the session uses this to
            # drop the responder from its in-flight table.
            if self._completed:  # pragma: no branch
                self._on_complete(self)
        finally:
            self._entered = False
            if not self._cancel_scope:  # pragma: no cover
                raise RuntimeError("No active cancel scope")
            self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)

    async def respond(self, response: SendResultT | ErrorData) -> None:
        """Send a response for this request.

        Must be called within a context manager block.

        Raises:
            RuntimeError: If not used within a context manager
            AssertionError: If request was already responded to
        """
        if not self._entered:  # pragma: no cover
            raise RuntimeError("RequestResponder must be used as a context manager")
        assert not self._completed, "Request already responded to"

        # A cancelled request already sent its error response in cancel().
        if not self.cancelled:  # pragma: no branch
            self._completed = True

            await self._session._send_response(  # type: ignore[reportPrivateUsage]
                request_id=self.request_id, response=response
            )

    async def cancel(self) -> None:
        """Cancel this request and mark it as completed."""
        if not self._entered:  # pragma: no cover
            raise RuntimeError("RequestResponder must be used as a context manager")
        if not self._cancel_scope:  # pragma: no cover
            raise RuntimeError("No active cancel scope")

        self._cancel_scope.cancel()
        self._completed = True  # Mark as completed so it's removed from in_flight
        # Send an error response to indicate cancellation
        await self._session._send_response(  # type: ignore[reportPrivateUsage]
            request_id=self.request_id,
            response=ErrorData(code=0, message="Request cancelled", data=None),
        )

    @property
    def in_flight(self) -> bool:  # pragma: no cover
        # True while the request has neither been answered nor cancelled.
        return not self._completed and not self.cancelled

    @property
    def cancelled(self) -> bool:  # pragma: no cover
        # Reflects whether cancel() has been called on the current scope.
        return self._cancel_scope.cancel_called
class BaseSession(
Generic[
SendRequestT,
SendNotificationT,
SendResultT,
ReceiveRequestT,
ReceiveNotificationT,
],
):
"""
Implements an MCP "session" on top of read/write streams, including features
like request/response linking, notifications, and progress.
This class is an async context manager that automatically starts processing
messages when entered.
"""
# Per-request 1-slot streams where _receive_loop delivers responses.
_response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]]
# Monotonically increasing ID assigned to outgoing requests.
_request_id: int
# Incoming requests currently being handled, keyed by request ID.
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
# Progress callbacks registered for outgoing requests, keyed by request ID.
_progress_callbacks: dict[RequestId, ProgressFnT]
# Experimental: routers consulted before the default response streams.
_response_routers: list["ResponseRouter"]

def __init__(
    self,
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
    write_stream: MemoryObjectSendStream[SessionMessage],
    receive_request_type: type[ReceiveRequestT],
    receive_notification_type: type[ReceiveNotificationT],
    # If none, reading will never time out
    read_timeout_seconds: timedelta | None = None,
) -> None:
    """Initialize session state over the given transport streams.

    Args:
        read_stream: Incoming messages (or exceptions) from the transport.
        write_stream: Outgoing messages to the transport.
        receive_request_type: Model used to validate incoming requests.
        receive_notification_type: Model used to validate incoming notifications.
        read_timeout_seconds: Default response timeout; None disables it.
    """
    self._read_stream = read_stream
    self._write_stream = write_stream
    self._response_streams = {}
    self._request_id = 0
    self._receive_request_type = receive_request_type
    self._receive_notification_type = receive_notification_type
    self._session_read_timeout_seconds = read_timeout_seconds
    self._in_flight = {}
    self._progress_callbacks = {}
    self._response_routers = []
    self._exit_stack = AsyncExitStack()
def add_response_router(self, router: ResponseRouter) -> None:
    """
    Register a response router to handle responses for non-standard requests.

    Response routers are checked in order before falling back to the default
    response stream mechanism. This is used by TaskResultHandler to route
    responses for queued task requests back to their resolvers.

    WARNING: This is an experimental API that may change without notice.

    Args:
        router: A ResponseRouter implementation
    """
    # Registration order is significant: routers are consulted FIFO.
    self._response_routers.append(router)
async def __aenter__(self) -> Self:
    """Start the background receive loop and return the session."""
    self._task_group = anyio.create_task_group()
    await self._task_group.__aenter__()
    # The receive loop runs for the lifetime of the session context.
    self._task_group.start_soon(self._receive_loop)
    return self
async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_val: BaseException | None,
    exc_tb: TracebackType | None,
) -> bool | None:
    """Close registered resources, then stop the receive loop."""
    await self._exit_stack.aclose()
    # Using BaseSession as a context manager should not block on exit (this
    # would be very surprising behavior), so make sure to cancel the tasks
    # in the task group.
    self._task_group.cancel_scope.cancel()
    return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
async def send_request(
    self,
    request: SendRequestT,
    result_type: type[ReceiveResultT],
    request_read_timeout_seconds: timedelta | None = None,
    metadata: MessageMetadata = None,
    progress_callback: ProgressFnT | None = None,
) -> ReceiveResultT:
    """
    Sends a request and wait for a response. Raises an McpError if the
    response contains an error. If a request read timeout is provided, it
    will take precedence over the session read timeout.

    Do not use this method to emit notifications! Use send_notification()
    instead.

    Args:
        request: The request to send.
        result_type: Model used to validate the successful response result.
        request_read_timeout_seconds: Per-request response timeout; overrides
            the session-level timeout when provided.
        metadata: Optional transport metadata attached to the outgoing message.
        progress_callback: When given, a progress token (the request ID) is
            injected into the request's _meta and this callback is invoked
            for matching progress notifications.

    Raises:
        McpError: If the response is an error, or on timeout.
    """
    request_id = self._request_id
    self._request_id = request_id + 1

    # Dedicated 1-slot stream where _receive_loop delivers the matching
    # response (or error) for this request ID.
    response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1)
    self._response_streams[request_id] = response_stream

    # Set up progress token if progress callback is provided
    request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
    if progress_callback is not None:  # pragma: no cover
        # Use request_id as progress token
        if "params" not in request_data:
            request_data["params"] = {}
        if "_meta" not in request_data["params"]:  # pragma: no branch
            request_data["params"]["_meta"] = {}
        request_data["params"]["_meta"]["progressToken"] = request_id
        # Store the callback for this request
        self._progress_callbacks[request_id] = progress_callback
    try:
        jsonrpc_request = JSONRPCRequest(
            jsonrpc="2.0",
            id=request_id,
            **request_data,
        )
        await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))

        # request read timeout takes precedence over session read timeout
        timeout = None
        if request_read_timeout_seconds is not None:  # pragma: no cover
            timeout = request_read_timeout_seconds.total_seconds()
        elif self._session_read_timeout_seconds is not None:  # pragma: no cover
            timeout = self._session_read_timeout_seconds.total_seconds()
        try:
            # fail_after(None) means wait indefinitely.
            with anyio.fail_after(timeout):
                response_or_error = await response_stream_reader.receive()
        except TimeoutError:
            raise McpError(
                ErrorData(
                    code=httpx.codes.REQUEST_TIMEOUT,
                    message=(
                        f"Timed out while waiting for response to "
                        f"{request.__class__.__name__}. Waited "
                        f"{timeout} seconds."
                    ),
                )
            )

        if isinstance(response_or_error, JSONRPCError):
            raise McpError(response_or_error.error)
        else:
            return result_type.model_validate(response_or_error.result)
    finally:
        # Always drop bookkeeping for this request ID and close both stream
        # ends, even on timeout, error, or cancellation.
        self._response_streams.pop(request_id, None)
        self._progress_callbacks.pop(request_id, None)
        await response_stream.aclose()
        await response_stream_reader.aclose()
async def send_notification(
    self,
    notification: SendNotificationT,
    related_request_id: RequestId | None = None,
) -> None:
    """
    Emits a notification, which is a one-way message that does not expect
    a response.

    Args:
        notification: The notification to emit.
        related_request_id: If given, attached as metadata so the transport
            can attribute the notification to the request that triggered it.
    """
    # Some transport implementations may need to set the related_request_id
    # to attribute to the notifications to the request that triggered them.
    jsonrpc_notification = JSONRPCNotification(
        jsonrpc="2.0",
        **notification.model_dump(by_alias=True, mode="json", exclude_none=True),
    )
    session_message = SessionMessage(  # pragma: no cover
        message=JSONRPCMessage(jsonrpc_notification),
        metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
    )
    await self._write_stream.send(session_message)
async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
if isinstance(response, ErrorData):
jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
await self._write_stream.send(session_message)
else:
jsonrpc_response = JSONRPCResponse(
jsonrpc="2.0",
id=request_id,
result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
)
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
await self._write_stream.send(session_message)
    async def _receive_loop(self) -> None:
        """Read messages from the transport until the read stream closes.

        Dispatch by message type: transport-level exceptions and unclaimed
        traffic go to ``_handle_incoming``; requests are validated, tracked in
        ``_in_flight`` and wrapped in a ``RequestResponder``; notifications are
        validated and forwarded (with special handling for cancellation and
        progress); responses/errors are routed via ``_handle_response``.
        On exit, all pending response streams receive a CONNECTION_CLOSED error.
        """
        # Owning both streams here ensures they are closed when the loop ends.
        async with (
            self._read_stream,
            self._write_stream,
        ):
            try:
                async for message in self._read_stream:
                    if isinstance(message, Exception):  # pragma: no cover
                        # Transport surfaced an error object instead of a message;
                        # hand it to the generic incoming handler.
                        await self._handle_incoming(message)
                    elif isinstance(message.message.root, JSONRPCRequest):
                        try:
                            validated_request = self._receive_request_type.model_validate(
                                message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
                            )
                            responder = RequestResponder(
                                request_id=message.message.root.id,
                                request_meta=validated_request.root.params.meta
                                if validated_request.root.params
                                else None,
                                request=validated_request,
                                session=self,
                                # Drop the in-flight entry as soon as the request completes.
                                on_complete=lambda r: self._in_flight.pop(r.request_id, None),
                                message_metadata=message.metadata,
                            )
                            # Track before dispatch so a CancelledNotification can find it.
                            self._in_flight[responder.request_id] = responder
                            await self._received_request(responder)

                            # Only forward requests the subclass did not answer itself.
                            if not responder._completed:  # type: ignore[reportPrivateUsage]
                                await self._handle_incoming(responder)
                        except Exception as e:
                            # For request validation errors, send a proper JSON-RPC error
                            # response instead of crashing the server
                            logging.warning(f"Failed to validate request: {e}")
                            logging.debug(f"Message that failed validation: {message.message.root}")
                            error_response = JSONRPCError(
                                jsonrpc="2.0",
                                id=message.message.root.id,
                                error=ErrorData(
                                    code=INVALID_PARAMS,
                                    message="Invalid request parameters",
                                    data="",
                                ),
                            )
                            session_message = SessionMessage(message=JSONRPCMessage(error_response))
                            await self._write_stream.send(session_message)
                    elif isinstance(message.message.root, JSONRPCNotification):
                        try:
                            notification = self._receive_notification_type.model_validate(
                                message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
                            )

                            # Handle cancellation notifications
                            if isinstance(notification.root, CancelledNotification):
                                cancelled_id = notification.root.params.requestId
                                if cancelled_id in self._in_flight:  # pragma: no branch
                                    await self._in_flight[cancelled_id].cancel()
                            else:
                                # Handle progress notifications callback
                                if isinstance(notification.root, ProgressNotification):  # pragma: no cover
                                    progress_token = notification.root.params.progressToken
                                    # If there is a progress callback for this token,
                                    # call it with the progress information
                                    if progress_token in self._progress_callbacks:
                                        callback = self._progress_callbacks[progress_token]
                                        try:
                                            await callback(
                                                notification.root.params.progress,
                                                notification.root.params.total,
                                                notification.root.params.message,
                                            )
                                        except Exception as e:
                                            # A failing callback must not break the receive loop.
                                            logging.error(
                                                "Progress callback raised an exception: %s",
                                                e,
                                            )
                                await self._received_notification(notification)
                                await self._handle_incoming(notification)
                        except Exception as e:  # pragma: no cover
                            # For other validation errors, log and continue
                            logging.warning(
                                f"Failed to validate notification: {e}. Message was: {message.message.root}"
                            )
                    else:  # Response or error
                        await self._handle_response(message)
            except anyio.ClosedResourceError:
                # This is expected when the client disconnects abruptly.
                # Without this handler, the exception would propagate up and
                # crash the server's task group.
                logging.debug("Read stream closed by client")  # pragma: no cover
            except Exception as e:  # pragma: no cover
                # Other exceptions are not expected and should be logged. We purposefully
                # catch all exceptions here to avoid crashing the server.
                logging.exception(f"Unhandled exception in receive loop: {e}")
            finally:
                # after the read stream is closed, we need to send errors
                # to any pending requests
                for id, stream in self._response_streams.items():
                    error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed")
                    try:
                        await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error))
                        await stream.aclose()
                    except Exception:  # pragma: no cover
                        # Stream might already be closed
                        pass
                self._response_streams.clear()
def _normalize_request_id(self, response_id: RequestId) -> RequestId:
"""
Normalize a response ID to match how request IDs are stored.
Since the client always sends integer IDs, we normalize string IDs
to integers when possible. This matches the TypeScript SDK approach:
https://github.com/modelcontextprotocol/typescript-sdk/blob/a606fb17909ea454e83aab14c73f14ea45c04448/src/shared/protocol.ts#L861
Args:
response_id: The response ID from the incoming message.
Returns:
The normalized ID (int if possible, otherwise original value).
"""
if isinstance(response_id, str):
try:
return int(response_id)
except ValueError:
logging.warning(f"Response ID {response_id!r} cannot be normalized to match pending requests")
return response_id
async def _handle_response(self, message: SessionMessage) -> None:
"""
Handle an incoming response or error message.
Checks response routers first (e.g., for task-related responses),
then falls back to the normal response stream mechanism.
"""
root = message.message.root
# This check is always true at runtime: the caller (_receive_loop) only invokes
# this method in the else branch after checking for JSONRPCRequest and
# JSONRPCNotification. However, the type checker can't infer this from the
# method signature, so we need this guard for type narrowing.
if not isinstance(root, JSONRPCResponse | JSONRPCError):
return # pragma: no cover
# Normalize response ID to handle type mismatches (e.g., "0" vs 0)
response_id = self._normalize_request_id(root.id)
# First, check response routers (e.g., TaskResultHandler)
if isinstance(root, JSONRPCError):
# Route error to routers
for router in self._response_routers:
if router.route_error(response_id, root.error):
return # Handled
else:
# Route success response to routers
response_data: dict[str, Any] = root.result or {}
for router in self._response_routers:
if router.route_response(response_id, response_data):
return # Handled
# Fall back to normal response streams
stream = self._response_streams.pop(response_id, None)
if stream: # pragma: no cover
await stream.send(root)
else: # pragma: no cover
await self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
    async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
        """
        Can be overridden by subclasses to handle a request without needing to
        listen on the message stream.

        If the request is responded to within this method, it will not be
        forwarded on to the message stream.

        The base implementation is intentionally a no-op.
        """
    async def _received_notification(self, notification: ReceiveNotificationT) -> None:
        """
        Can be overridden by subclasses to handle a notification without needing
        to listen on the message stream.

        The base implementation is intentionally a no-op.
        """
    async def send_progress_notification(
        self,
        progress_token: str | int,
        progress: float,
        total: float | None = None,
        message: str | None = None,
    ) -> None:
        """
        Sends a progress notification for a request that is currently being
        processed.

        Args:
            progress_token: Token identifying the request the progress applies to.
            progress: Amount of work completed so far.
            total: Total amount of work, if known.
            message: Optional human-readable status message.

        The base implementation is a no-op; presumably concrete sessions
        override it to actually emit the notification (verify in subclasses).
        """
    async def _handle_incoming(
        self,
        req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
    ) -> None:
        """A generic handler for incoming messages. Overwritten by subclasses.

        Args:
            req: An unanswered request responder, a validated notification, or
                a transport error surfaced by the receive loop.
        """
        pass  # pragma: no cover

View File

@@ -0,0 +1,129 @@
"""Tool name validation utilities according to SEP-986.
Tool names SHOULD be between 1 and 128 characters in length (inclusive).
Tool names are case-sensitive.
Allowed characters: uppercase and lowercase ASCII letters (A-Z, a-z),
digits (0-9), underscore (_), dash (-), and dot (.).
Tool names SHOULD NOT contain spaces, commas, or other special characters.
See: https://modelcontextprotocol.io/specification/2025-11-25/server/tools#tool-names
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
# Regular expression for valid tool names according to SEP-986 specification
# Regular expression for valid tool names according to SEP-986 specification
TOOL_NAME_REGEX = re.compile(r"^[A-Za-z0-9._-]{1,128}$")

# Characters permitted in tool names; mirrors TOOL_NAME_REGEX's character class
# and lets the invalid-character scan avoid a regex call per character.
_ALLOWED_TOOL_NAME_CHARS = frozenset(
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-"
)

# SEP reference URL for warning messages
SEP_986_URL = "https://modelcontextprotocol.io/specification/2025-11-25/server/tools#tool-names"


@dataclass
class ToolNameValidationResult:
    """Result of tool name validation.

    Attributes:
        is_valid: Whether the tool name conforms to SEP-986 requirements.
        warnings: List of warning messages for non-conforming aspects. May be
            non-empty even when ``is_valid`` is True (e.g. a leading dash).
    """

    is_valid: bool
    # default_factory=list gives each instance its own fresh list.
    warnings: list[str] = field(default_factory=list)


def validate_tool_name(name: str) -> ToolNameValidationResult:
    """Validate a tool name according to the SEP-986 specification.

    Args:
        name: The tool name to validate.

    Returns:
        ToolNameValidationResult containing validation status and any warnings.
    """
    warnings: list[str] = []

    # Check for empty name
    if not name:
        return ToolNameValidationResult(
            is_valid=False,
            warnings=["Tool name cannot be empty"],
        )

    # Check length
    if len(name) > 128:
        return ToolNameValidationResult(
            is_valid=False,
            warnings=[f"Tool name exceeds maximum length of 128 characters (current: {len(name)})"],
        )

    # Check for problematic patterns (warnings, not validation failures;
    # spaces and commas also fail the character-set check below)
    if " " in name:
        warnings.append("Tool name contains spaces, which may cause parsing issues")
    if "," in name:
        warnings.append("Tool name contains commas, which may cause parsing issues")

    # Check for potentially confusing leading/trailing characters
    if name.startswith("-") or name.endswith("-"):
        warnings.append("Tool name starts or ends with a dash, which may cause parsing issues in some contexts")
    if name.startswith(".") or name.endswith("."):
        warnings.append("Tool name starts or ends with a dot, which may cause parsing issues in some contexts")

    # Check for invalid characters
    if not TOOL_NAME_REGEX.match(name):
        # Find all invalid characters (unique, preserving order)
        invalid_chars: list[str] = []
        seen: set[str] = set()
        for char in name:
            if char not in _ALLOWED_TOOL_NAME_CHARS and char not in seen:
                invalid_chars.append(char)
                seen.add(char)
        warnings.append(f"Tool name contains invalid characters: {', '.join(repr(c) for c in invalid_chars)}")
        warnings.append("Allowed characters are: A-Z, a-z, 0-9, underscore (_), dash (-), and dot (.)")
        return ToolNameValidationResult(is_valid=False, warnings=warnings)

    return ToolNameValidationResult(is_valid=True, warnings=warnings)
def issue_tool_name_warning(name: str, warnings: list[str]) -> None:
    """Log warnings for non-conforming tool names.

    Does nothing when *warnings* is empty.

    Args:
        name: The tool name that triggered the warnings.
        warnings: List of warning messages to log.
    """
    if not warnings:
        return
    logger.warning(f'Tool name validation warning for "{name}":')
    for item in warnings:
        logger.warning(f"  - {item}")
    # Trailer advising the author; registration itself is not blocked.
    for line in (
        "Tool registration will proceed, but this may cause compatibility issues.",
        "Consider updating the tool name to conform to the MCP tool naming standard.",
        f"See SEP-986 ({SEP_986_URL}) for more details.",
    ):
        logger.warning(line)
def validate_and_warn_tool_name(name: str) -> bool:
    """Validate a tool name and issue warnings for non-conforming names.

    This is the primary entry point for tool name validation. It validates
    the name and logs any warnings via the logging module.

    Args:
        name: The tool name to validate.

    Returns:
        True if the name is valid, False otherwise.
    """
    outcome = validate_tool_name(name)
    issue_tool_name_warning(name, outcome.warnings)
    return outcome.is_valid

View File

@@ -0,0 +1,3 @@
from mcp.types import LATEST_PROTOCOL_VERSION

# Protocol revisions this client/server can negotiate, listed oldest first;
# the newest supported revision comes from mcp.types.LATEST_PROTOCOL_VERSION.
SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", "2025-06-18", LATEST_PROTOCOL_VERSION]