Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for active project before fetching all sessions - When active project exists, use project.sessions instead of fetching from API - Added detailed console logging to debug session filtering - This prevents ALL sessions from appearing in every project's sidebar Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Pure task state management for MCP.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Import directly from submodules:
|
||||
- mcp.shared.experimental.tasks.store.TaskStore
|
||||
- mcp.shared.experimental.tasks.context.TaskContext
|
||||
- mcp.shared.experimental.tasks.in_memory_task_store.InMemoryTaskStore
|
||||
- mcp.shared.experimental.tasks.message_queue.TaskMessageQueue
|
||||
- mcp.shared.experimental.tasks.helpers.is_terminal
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Tasks capability checking utilities.
|
||||
|
||||
This module provides functions for checking and requiring task-related
|
||||
capabilities. All tasks capability logic is centralized here to keep
|
||||
the main session code clean.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.types import (
|
||||
INVALID_REQUEST,
|
||||
ClientCapabilities,
|
||||
ClientTasksCapability,
|
||||
ErrorData,
|
||||
)
|
||||
|
||||
|
||||
def check_tasks_capability(
|
||||
required: ClientTasksCapability,
|
||||
client: ClientTasksCapability,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if client's tasks capability matches the required capability.
|
||||
|
||||
Args:
|
||||
required: The capability being checked for
|
||||
client: The client's declared capabilities
|
||||
|
||||
Returns:
|
||||
True if client has the required capability, False otherwise
|
||||
"""
|
||||
if required.requests is None:
|
||||
return True
|
||||
if client.requests is None:
|
||||
return False
|
||||
|
||||
# Check elicitation.create
|
||||
if required.requests.elicitation is not None:
|
||||
if client.requests.elicitation is None:
|
||||
return False
|
||||
if required.requests.elicitation.create is not None:
|
||||
if client.requests.elicitation.create is None:
|
||||
return False
|
||||
|
||||
# Check sampling.createMessage
|
||||
if required.requests.sampling is not None:
|
||||
if client.requests.sampling is None:
|
||||
return False
|
||||
if required.requests.sampling.createMessage is not None:
|
||||
if client.requests.sampling.createMessage is None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def has_task_augmented_elicitation(caps: ClientCapabilities) -> bool:
|
||||
"""Check if capabilities include task-augmented elicitation support."""
|
||||
if caps.tasks is None:
|
||||
return False
|
||||
if caps.tasks.requests is None:
|
||||
return False
|
||||
if caps.tasks.requests.elicitation is None:
|
||||
return False
|
||||
return caps.tasks.requests.elicitation.create is not None
|
||||
|
||||
|
||||
def has_task_augmented_sampling(caps: ClientCapabilities) -> bool:
|
||||
"""Check if capabilities include task-augmented sampling support."""
|
||||
if caps.tasks is None:
|
||||
return False
|
||||
if caps.tasks.requests is None:
|
||||
return False
|
||||
if caps.tasks.requests.sampling is None:
|
||||
return False
|
||||
return caps.tasks.requests.sampling.createMessage is not None
|
||||
|
||||
|
||||
def require_task_augmented_elicitation(client_caps: ClientCapabilities | None) -> None:
|
||||
"""
|
||||
Raise McpError if client doesn't support task-augmented elicitation.
|
||||
|
||||
Args:
|
||||
client_caps: The client's declared capabilities, or None if not initialized
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support task-augmented elicitation
|
||||
"""
|
||||
if client_caps is None or not has_task_augmented_elicitation(client_caps):
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_REQUEST,
|
||||
message="Client does not support task-augmented elicitation",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def require_task_augmented_sampling(client_caps: ClientCapabilities | None) -> None:
|
||||
"""
|
||||
Raise McpError if client doesn't support task-augmented sampling.
|
||||
|
||||
Args:
|
||||
client_caps: The client's declared capabilities, or None if not initialized
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support task-augmented sampling
|
||||
"""
|
||||
if client_caps is None or not has_task_augmented_sampling(client_caps):
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_REQUEST,
|
||||
message="Client does not support task-augmented sampling",
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
TaskContext - Pure task state management.
|
||||
|
||||
This module provides TaskContext, which manages task state without any
|
||||
server/session dependencies. It can be used standalone for distributed
|
||||
workers or wrapped by ServerTaskContext for full server integration.
|
||||
"""
|
||||
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
from mcp.types import TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, Result, Task
|
||||
|
||||
|
||||
class TaskContext:
|
||||
"""
|
||||
Pure task state management - no session dependencies.
|
||||
|
||||
This class handles:
|
||||
- Task state (status, result)
|
||||
- Cancellation tracking
|
||||
- Store interactions
|
||||
|
||||
For server-integrated features (elicit, create_message, notifications),
|
||||
use ServerTaskContext from mcp.server.experimental.
|
||||
|
||||
Example (distributed worker):
|
||||
async def worker_job(task_id: str):
|
||||
store = RedisTaskStore(redis_url)
|
||||
task = await store.get_task(task_id)
|
||||
ctx = TaskContext(task=task, store=store)
|
||||
|
||||
await ctx.update_status("Working...")
|
||||
result = await do_work()
|
||||
await ctx.complete(result)
|
||||
"""
|
||||
|
||||
def __init__(self, task: Task, store: TaskStore):
|
||||
self._task = task
|
||||
self._store = store
|
||||
self._cancelled = False
|
||||
|
||||
@property
|
||||
def task_id(self) -> str:
|
||||
"""The task identifier."""
|
||||
return self._task.taskId
|
||||
|
||||
@property
|
||||
def task(self) -> Task:
|
||||
"""The current task state."""
|
||||
return self._task
|
||||
|
||||
@property
|
||||
def is_cancelled(self) -> bool:
|
||||
"""Whether cancellation has been requested."""
|
||||
return self._cancelled
|
||||
|
||||
def request_cancellation(self) -> None:
|
||||
"""
|
||||
Request cancellation of this task.
|
||||
|
||||
This sets is_cancelled=True. Task work should check this
|
||||
periodically and exit gracefully if set.
|
||||
"""
|
||||
self._cancelled = True
|
||||
|
||||
async def update_status(self, message: str) -> None:
|
||||
"""
|
||||
Update the task's status message.
|
||||
|
||||
Args:
|
||||
message: The new status message
|
||||
"""
|
||||
self._task = await self._store.update_task(
|
||||
self.task_id,
|
||||
status_message=message,
|
||||
)
|
||||
|
||||
async def complete(self, result: Result) -> None:
|
||||
"""
|
||||
Mark the task as completed with the given result.
|
||||
|
||||
Args:
|
||||
result: The task result
|
||||
"""
|
||||
await self._store.store_result(self.task_id, result)
|
||||
self._task = await self._store.update_task(
|
||||
self.task_id,
|
||||
status=TASK_STATUS_COMPLETED,
|
||||
)
|
||||
|
||||
async def fail(self, error: str) -> None:
|
||||
"""
|
||||
Mark the task as failed with an error message.
|
||||
|
||||
Args:
|
||||
error: The error message
|
||||
"""
|
||||
self._task = await self._store.update_task(
|
||||
self.task_id,
|
||||
status=TASK_STATUS_FAILED,
|
||||
status_message=error,
|
||||
)
|
||||
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
Helper functions for pure task management.
|
||||
|
||||
These helpers work with pure TaskContext and don't require server dependencies.
|
||||
For server-integrated task helpers, use mcp.server.experimental.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.experimental.tasks.context import TaskContext
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
from mcp.types import (
|
||||
INVALID_PARAMS,
|
||||
TASK_STATUS_CANCELLED,
|
||||
TASK_STATUS_COMPLETED,
|
||||
TASK_STATUS_FAILED,
|
||||
TASK_STATUS_WORKING,
|
||||
CancelTaskResult,
|
||||
ErrorData,
|
||||
Task,
|
||||
TaskMetadata,
|
||||
TaskStatus,
|
||||
)
|
||||
|
||||
# Metadata key for model-immediate-response (per MCP spec)
|
||||
# Servers MAY include this in CreateTaskResult._meta to provide an immediate
|
||||
# response string while the task executes in the background.
|
||||
MODEL_IMMEDIATE_RESPONSE_KEY = "io.modelcontextprotocol/model-immediate-response"
|
||||
|
||||
# Metadata key for associating requests with a task (per MCP spec)
|
||||
RELATED_TASK_METADATA_KEY = "io.modelcontextprotocol/related-task"
|
||||
|
||||
|
||||
def is_terminal(status: TaskStatus) -> bool:
|
||||
"""
|
||||
Check if a task status represents a terminal state.
|
||||
|
||||
Terminal states are those where the task has finished and will not change.
|
||||
|
||||
Args:
|
||||
status: The task status to check
|
||||
|
||||
Returns:
|
||||
True if the status is terminal (completed, failed, or cancelled)
|
||||
"""
|
||||
return status in (TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED)
|
||||
|
||||
|
||||
async def cancel_task(
|
||||
store: TaskStore,
|
||||
task_id: str,
|
||||
) -> CancelTaskResult:
|
||||
"""
|
||||
Cancel a task with spec-compliant validation.
|
||||
|
||||
Per spec: "Receivers MUST reject cancellation of terminal status tasks
|
||||
with -32602 (Invalid params)"
|
||||
|
||||
This helper validates that the task exists and is not in a terminal state
|
||||
before setting it to "cancelled".
|
||||
|
||||
Args:
|
||||
store: The task store
|
||||
task_id: The task identifier to cancel
|
||||
|
||||
Returns:
|
||||
CancelTaskResult with the cancelled task state
|
||||
|
||||
Raises:
|
||||
McpError: With INVALID_PARAMS (-32602) if:
|
||||
- Task does not exist
|
||||
- Task is already in a terminal state (completed, failed, cancelled)
|
||||
|
||||
Example:
|
||||
@server.experimental.cancel_task()
|
||||
async def handle_cancel(request: CancelTaskRequest) -> CancelTaskResult:
|
||||
return await cancel_task(store, request.params.taskId)
|
||||
"""
|
||||
task = await store.get_task(task_id)
|
||||
if task is None:
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_PARAMS,
|
||||
message=f"Task not found: {task_id}",
|
||||
)
|
||||
)
|
||||
|
||||
if is_terminal(task.status):
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_PARAMS,
|
||||
message=f"Cannot cancel task in terminal state '{task.status}'",
|
||||
)
|
||||
)
|
||||
|
||||
# Update task to cancelled status
|
||||
cancelled_task = await store.update_task(task_id, status=TASK_STATUS_CANCELLED)
|
||||
return CancelTaskResult(**cancelled_task.model_dump())
|
||||
|
||||
|
||||
def generate_task_id() -> str:
|
||||
"""Generate a unique task ID."""
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
def create_task_state(
|
||||
metadata: TaskMetadata,
|
||||
task_id: str | None = None,
|
||||
) -> Task:
|
||||
"""
|
||||
Create a Task object with initial state.
|
||||
|
||||
This is a helper for TaskStore implementations.
|
||||
|
||||
Args:
|
||||
metadata: Task metadata
|
||||
task_id: Optional task ID (generated if not provided)
|
||||
|
||||
Returns:
|
||||
A new Task in "working" status
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
return Task(
|
||||
taskId=task_id or generate_task_id(),
|
||||
status=TASK_STATUS_WORKING,
|
||||
createdAt=now,
|
||||
lastUpdatedAt=now,
|
||||
ttl=metadata.ttl,
|
||||
pollInterval=500, # Default 500ms poll interval
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def task_execution(
|
||||
task_id: str,
|
||||
store: TaskStore,
|
||||
) -> AsyncIterator[TaskContext]:
|
||||
"""
|
||||
Context manager for safe task execution (pure, no server dependencies).
|
||||
|
||||
Loads a task from the store and provides a TaskContext for the work.
|
||||
If an unhandled exception occurs, the task is automatically marked as failed
|
||||
and the exception is suppressed (since the failure is captured in task state).
|
||||
|
||||
This is useful for distributed workers that don't have a server session.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier to execute
|
||||
store: The task store (must be accessible by the worker)
|
||||
|
||||
Yields:
|
||||
TaskContext for updating status and completing/failing the task
|
||||
|
||||
Raises:
|
||||
ValueError: If the task is not found in the store
|
||||
|
||||
Example (distributed worker):
|
||||
async def worker_process(task_id: str):
|
||||
store = RedisTaskStore(redis_url)
|
||||
async with task_execution(task_id, store) as ctx:
|
||||
await ctx.update_status("Working...")
|
||||
result = await do_work()
|
||||
await ctx.complete(result)
|
||||
"""
|
||||
task = await store.get_task(task_id)
|
||||
if task is None:
|
||||
raise ValueError(f"Task {task_id} not found")
|
||||
|
||||
ctx = TaskContext(task, store)
|
||||
try:
|
||||
yield ctx
|
||||
except Exception as e:
|
||||
# Auto-fail the task if an exception occurs and task isn't already terminal
|
||||
# Exception is suppressed since failure is captured in task state
|
||||
if not is_terminal(ctx.task.status):
|
||||
await ctx.fail(str(e))
|
||||
# Don't re-raise - the failure is recorded in task state
|
||||
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
In-memory implementation of TaskStore for demonstration purposes.
|
||||
|
||||
This implementation stores all tasks in memory and provides automatic cleanup
|
||||
based on the TTL duration specified in the task metadata using lazy expiration.
|
||||
|
||||
Note: This is not suitable for production use as all data is lost on restart.
|
||||
For production, consider implementing TaskStore with a database or distributed cache.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.shared.experimental.tasks.helpers import create_task_state, is_terminal
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
from mcp.types import Result, Task, TaskMetadata, TaskStatus
|
||||
|
||||
|
||||
@dataclass
|
||||
class StoredTask:
|
||||
"""Internal storage representation of a task."""
|
||||
|
||||
task: Task
|
||||
result: Result | None = None
|
||||
# Time when this task should be removed (None = never)
|
||||
expires_at: datetime | None = field(default=None)
|
||||
|
||||
|
||||
class InMemoryTaskStore(TaskStore):
|
||||
"""
|
||||
A simple in-memory implementation of TaskStore.
|
||||
|
||||
Features:
|
||||
- Automatic TTL-based cleanup (lazy expiration)
|
||||
- Thread-safe for single-process async use
|
||||
- Pagination support for list_tasks
|
||||
|
||||
Limitations:
|
||||
- All data lost on restart
|
||||
- Not suitable for distributed systems
|
||||
- No persistence
|
||||
|
||||
For production, implement TaskStore with Redis, PostgreSQL, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, page_size: int = 10) -> None:
|
||||
self._tasks: dict[str, StoredTask] = {}
|
||||
self._page_size = page_size
|
||||
self._update_events: dict[str, anyio.Event] = {}
|
||||
|
||||
def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
|
||||
"""Calculate expiry time from TTL in milliseconds."""
|
||||
if ttl_ms is None:
|
||||
return None
|
||||
return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms)
|
||||
|
||||
def _is_expired(self, stored: StoredTask) -> bool:
|
||||
"""Check if a task has expired."""
|
||||
if stored.expires_at is None:
|
||||
return False
|
||||
return datetime.now(timezone.utc) >= stored.expires_at
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Remove all expired tasks. Called lazily during access operations."""
|
||||
expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)]
|
||||
for task_id in expired_ids:
|
||||
del self._tasks[task_id]
|
||||
|
||||
async def create_task(
|
||||
self,
|
||||
metadata: TaskMetadata,
|
||||
task_id: str | None = None,
|
||||
) -> Task:
|
||||
"""Create a new task with the given metadata."""
|
||||
# Cleanup expired tasks on access
|
||||
self._cleanup_expired()
|
||||
|
||||
task = create_task_state(metadata, task_id)
|
||||
|
||||
if task.taskId in self._tasks:
|
||||
raise ValueError(f"Task with ID {task.taskId} already exists")
|
||||
|
||||
stored = StoredTask(
|
||||
task=task,
|
||||
expires_at=self._calculate_expiry(metadata.ttl),
|
||||
)
|
||||
self._tasks[task.taskId] = stored
|
||||
|
||||
# Return a copy to prevent external modification
|
||||
return Task(**task.model_dump())
|
||||
|
||||
async def get_task(self, task_id: str) -> Task | None:
|
||||
"""Get a task by ID."""
|
||||
# Cleanup expired tasks on access
|
||||
self._cleanup_expired()
|
||||
|
||||
stored = self._tasks.get(task_id)
|
||||
if stored is None:
|
||||
return None
|
||||
|
||||
# Return a copy to prevent external modification
|
||||
return Task(**stored.task.model_dump())
|
||||
|
||||
async def update_task(
|
||||
self,
|
||||
task_id: str,
|
||||
status: TaskStatus | None = None,
|
||||
status_message: str | None = None,
|
||||
) -> Task:
|
||||
"""Update a task's status and/or message."""
|
||||
stored = self._tasks.get(task_id)
|
||||
if stored is None:
|
||||
raise ValueError(f"Task with ID {task_id} not found")
|
||||
|
||||
# Per spec: Terminal states MUST NOT transition to any other status
|
||||
if status is not None and status != stored.task.status and is_terminal(stored.task.status):
|
||||
raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'")
|
||||
|
||||
status_changed = False
|
||||
if status is not None and stored.task.status != status:
|
||||
stored.task.status = status
|
||||
status_changed = True
|
||||
|
||||
if status_message is not None:
|
||||
stored.task.statusMessage = status_message
|
||||
|
||||
# Update lastUpdatedAt on any change
|
||||
stored.task.lastUpdatedAt = datetime.now(timezone.utc)
|
||||
|
||||
# If task is now terminal and has TTL, reset expiry timer
|
||||
if status is not None and is_terminal(status) and stored.task.ttl is not None:
|
||||
stored.expires_at = self._calculate_expiry(stored.task.ttl)
|
||||
|
||||
# Notify waiters if status changed
|
||||
if status_changed:
|
||||
await self.notify_update(task_id)
|
||||
|
||||
return Task(**stored.task.model_dump())
|
||||
|
||||
async def store_result(self, task_id: str, result: Result) -> None:
|
||||
"""Store the result for a task."""
|
||||
stored = self._tasks.get(task_id)
|
||||
if stored is None:
|
||||
raise ValueError(f"Task with ID {task_id} not found")
|
||||
|
||||
stored.result = result
|
||||
|
||||
async def get_result(self, task_id: str) -> Result | None:
|
||||
"""Get the stored result for a task."""
|
||||
stored = self._tasks.get(task_id)
|
||||
if stored is None:
|
||||
return None
|
||||
|
||||
return stored.result
|
||||
|
||||
async def list_tasks(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
) -> tuple[list[Task], str | None]:
|
||||
"""List tasks with pagination."""
|
||||
# Cleanup expired tasks on access
|
||||
self._cleanup_expired()
|
||||
|
||||
all_task_ids = list(self._tasks.keys())
|
||||
|
||||
start_index = 0
|
||||
if cursor is not None:
|
||||
try:
|
||||
cursor_index = all_task_ids.index(cursor)
|
||||
start_index = cursor_index + 1
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid cursor: {cursor}")
|
||||
|
||||
page_task_ids = all_task_ids[start_index : start_index + self._page_size]
|
||||
tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids]
|
||||
|
||||
# Determine next cursor
|
||||
next_cursor = None
|
||||
if start_index + self._page_size < len(all_task_ids) and page_task_ids:
|
||||
next_cursor = page_task_ids[-1]
|
||||
|
||||
return tasks, next_cursor
|
||||
|
||||
async def delete_task(self, task_id: str) -> bool:
|
||||
"""Delete a task."""
|
||||
if task_id not in self._tasks:
|
||||
return False
|
||||
|
||||
del self._tasks[task_id]
|
||||
return True
|
||||
|
||||
async def wait_for_update(self, task_id: str) -> None:
|
||||
"""Wait until the task status changes."""
|
||||
if task_id not in self._tasks:
|
||||
raise ValueError(f"Task with ID {task_id} not found")
|
||||
|
||||
# Create a fresh event for waiting (anyio.Event can't be cleared)
|
||||
self._update_events[task_id] = anyio.Event()
|
||||
event = self._update_events[task_id]
|
||||
await event.wait()
|
||||
|
||||
async def notify_update(self, task_id: str) -> None:
|
||||
"""Signal that a task has been updated."""
|
||||
if task_id in self._update_events:
|
||||
self._update_events[task_id].set()
|
||||
|
||||
# --- Testing/debugging helpers ---
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Cleanup all tasks (useful for testing or graceful shutdown)."""
|
||||
self._tasks.clear()
|
||||
self._update_events.clear()
|
||||
|
||||
def get_all_tasks(self) -> list[Task]:
|
||||
"""Get all tasks (useful for debugging). Returns copies to prevent modification."""
|
||||
self._cleanup_expired()
|
||||
return [Task(**stored.task.model_dump()) for stored in self._tasks.values()]
|
||||
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
TaskMessageQueue - FIFO queue for task-related messages.
|
||||
|
||||
This implements the core message queue pattern from the MCP Tasks spec.
|
||||
When a handler needs to send a request (like elicitation) during a task-augmented
|
||||
request, the message is enqueued instead of sent directly. Messages are delivered
|
||||
to the client only through the `tasks/result` endpoint.
|
||||
|
||||
This pattern enables:
|
||||
1. Decoupling request handling from message delivery
|
||||
2. Proper bidirectional communication via the tasks/result stream
|
||||
3. Automatic status management (working <-> input_required)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.shared.experimental.tasks.resolver import Resolver
|
||||
from mcp.types import JSONRPCNotification, JSONRPCRequest, RequestId
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueuedMessage:
|
||||
"""
|
||||
A message queued for delivery via tasks/result.
|
||||
|
||||
Messages are stored with their type and a resolver for requests
|
||||
that expect responses.
|
||||
"""
|
||||
|
||||
type: Literal["request", "notification"]
|
||||
"""Whether this is a request (expects response) or notification (one-way)."""
|
||||
|
||||
message: JSONRPCRequest | JSONRPCNotification
|
||||
"""The JSON-RPC message to send."""
|
||||
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
"""When the message was enqueued."""
|
||||
|
||||
resolver: Resolver[dict[str, Any]] | None = None
|
||||
"""Resolver to set when response arrives (only for requests)."""
|
||||
|
||||
original_request_id: RequestId | None = None
|
||||
"""The original request ID used internally, for routing responses back."""
|
||||
|
||||
|
||||
class TaskMessageQueue(ABC):
|
||||
"""
|
||||
Abstract interface for task message queuing.
|
||||
|
||||
This is a FIFO queue that stores messages to be delivered via `tasks/result`.
|
||||
When a task-augmented handler calls elicit() or sends a notification, the
|
||||
message is enqueued here instead of being sent directly to the client.
|
||||
|
||||
The `tasks/result` handler then dequeues and sends these messages through
|
||||
the transport, with `relatedRequestId` set to the tasks/result request ID
|
||||
so responses are routed correctly.
|
||||
|
||||
Implementations can use in-memory storage, Redis, etc.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def enqueue(self, task_id: str, message: QueuedMessage) -> None:
|
||||
"""
|
||||
Add a message to the queue for a task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
message: The message to enqueue
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def dequeue(self, task_id: str) -> QueuedMessage | None:
|
||||
"""
|
||||
Remove and return the next message from the queue.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
The next message, or None if queue is empty
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def peek(self, task_id: str) -> QueuedMessage | None:
|
||||
"""
|
||||
Return the next message without removing it.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
The next message, or None if queue is empty
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def is_empty(self, task_id: str) -> bool:
|
||||
"""
|
||||
Check if the queue is empty for a task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
True if no messages are queued
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self, task_id: str) -> list[QueuedMessage]:
|
||||
"""
|
||||
Remove and return all messages from the queue.
|
||||
|
||||
This is useful for cleanup when a task is cancelled or completed.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
All queued messages (may be empty)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def wait_for_message(self, task_id: str) -> None:
|
||||
"""
|
||||
Wait until a message is available in the queue.
|
||||
|
||||
This blocks until either:
|
||||
1. A message is enqueued for this task
|
||||
2. The wait is cancelled
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def notify_message_available(self, task_id: str) -> None:
|
||||
"""
|
||||
Signal that a message is available for a task.
|
||||
|
||||
This wakes up any coroutines waiting in wait_for_message().
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
"""
|
||||
|
||||
|
||||
class InMemoryTaskMessageQueue(TaskMessageQueue):
|
||||
"""
|
||||
In-memory implementation of TaskMessageQueue.
|
||||
|
||||
This is suitable for single-process servers. For distributed systems,
|
||||
implement TaskMessageQueue with Redis, RabbitMQ, etc.
|
||||
|
||||
Features:
|
||||
- FIFO ordering per task
|
||||
- Async wait for message availability
|
||||
- Thread-safe for single-process async use
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._queues: dict[str, list[QueuedMessage]] = {}
|
||||
self._events: dict[str, anyio.Event] = {}
|
||||
|
||||
def _get_queue(self, task_id: str) -> list[QueuedMessage]:
|
||||
"""Get or create the queue for a task."""
|
||||
if task_id not in self._queues:
|
||||
self._queues[task_id] = []
|
||||
return self._queues[task_id]
|
||||
|
||||
async def enqueue(self, task_id: str, message: QueuedMessage) -> None:
|
||||
"""Add a message to the queue."""
|
||||
queue = self._get_queue(task_id)
|
||||
queue.append(message)
|
||||
# Signal that a message is available
|
||||
await self.notify_message_available(task_id)
|
||||
|
||||
async def dequeue(self, task_id: str) -> QueuedMessage | None:
|
||||
"""Remove and return the next message."""
|
||||
queue = self._get_queue(task_id)
|
||||
if not queue:
|
||||
return None
|
||||
return queue.pop(0)
|
||||
|
||||
async def peek(self, task_id: str) -> QueuedMessage | None:
|
||||
"""Return the next message without removing it."""
|
||||
queue = self._get_queue(task_id)
|
||||
if not queue:
|
||||
return None
|
||||
return queue[0]
|
||||
|
||||
async def is_empty(self, task_id: str) -> bool:
|
||||
"""Check if the queue is empty."""
|
||||
queue = self._get_queue(task_id)
|
||||
return len(queue) == 0
|
||||
|
||||
async def clear(self, task_id: str) -> list[QueuedMessage]:
|
||||
"""Remove and return all messages."""
|
||||
queue = self._get_queue(task_id)
|
||||
messages = list(queue)
|
||||
queue.clear()
|
||||
return messages
|
||||
|
||||
async def wait_for_message(self, task_id: str) -> None:
|
||||
"""Wait until a message is available."""
|
||||
# Check if there are already messages
|
||||
if not await self.is_empty(task_id):
|
||||
return
|
||||
|
||||
# Create a fresh event for waiting (anyio.Event can't be cleared)
|
||||
self._events[task_id] = anyio.Event()
|
||||
event = self._events[task_id]
|
||||
|
||||
# Double-check after creating event (avoid race condition)
|
||||
if not await self.is_empty(task_id):
|
||||
return
|
||||
|
||||
# Wait for a new message
|
||||
await event.wait()
|
||||
|
||||
async def notify_message_available(self, task_id: str) -> None:
|
||||
"""Signal that a message is available."""
|
||||
if task_id in self._events:
|
||||
self._events[task_id].set()
|
||||
|
||||
def cleanup(self, task_id: str | None = None) -> None:
|
||||
"""
|
||||
Clean up queues and events.
|
||||
|
||||
Args:
|
||||
task_id: If provided, clean up only this task. Otherwise clean up all.
|
||||
"""
|
||||
if task_id is not None:
|
||||
self._queues.pop(task_id, None)
|
||||
self._events.pop(task_id, None)
|
||||
else:
|
||||
self._queues.clear()
|
||||
self._events.clear()
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Shared polling utilities for task operations.
|
||||
|
||||
This module provides generic polling logic that works for both client→server
|
||||
and server→client task polling.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator, Awaitable, Callable
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.shared.experimental.tasks.helpers import is_terminal
|
||||
from mcp.types import GetTaskResult
|
||||
|
||||
|
||||
async def poll_until_terminal(
|
||||
get_task: Callable[[str], Awaitable[GetTaskResult]],
|
||||
task_id: str,
|
||||
default_interval_ms: int = 500,
|
||||
) -> AsyncIterator[GetTaskResult]:
|
||||
"""
|
||||
Poll a task until it reaches terminal status.
|
||||
|
||||
This is a generic utility that works for both client→server and server→client
|
||||
polling. The caller provides the get_task function appropriate for their direction.
|
||||
|
||||
Args:
|
||||
get_task: Async function that takes task_id and returns GetTaskResult
|
||||
task_id: The task to poll
|
||||
default_interval_ms: Fallback poll interval if server doesn't specify
|
||||
|
||||
Yields:
|
||||
GetTaskResult for each poll
|
||||
"""
|
||||
while True:
|
||||
status = await get_task(task_id)
|
||||
yield status
|
||||
|
||||
if is_terminal(status.status):
|
||||
break
|
||||
|
||||
interval_ms = status.pollInterval if status.pollInterval is not None else default_interval_ms
|
||||
await anyio.sleep(interval_ms / 1000)
|
||||
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
Resolver - An anyio-compatible future-like object for async result passing.
|
||||
|
||||
This provides a simple way to pass a result (or exception) from one coroutine
|
||||
to another without depending on asyncio.Future.
|
||||
"""
|
||||
|
||||
from typing import Generic, TypeVar, cast
|
||||
|
||||
import anyio
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Resolver(Generic[T]):
|
||||
"""
|
||||
A simple resolver for passing results between coroutines.
|
||||
|
||||
Unlike asyncio.Future, this works with any anyio-compatible async backend.
|
||||
|
||||
Usage:
|
||||
resolver: Resolver[str] = Resolver()
|
||||
|
||||
# In one coroutine:
|
||||
resolver.set_result("hello")
|
||||
|
||||
# In another coroutine:
|
||||
result = await resolver.wait() # returns "hello"
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._event = anyio.Event()
|
||||
self._value: T | None = None
|
||||
self._exception: BaseException | None = None
|
||||
|
||||
def set_result(self, value: T) -> None:
|
||||
"""Set the result value and wake up waiters."""
|
||||
if self._event.is_set():
|
||||
raise RuntimeError("Resolver already completed")
|
||||
self._value = value
|
||||
self._event.set()
|
||||
|
||||
def set_exception(self, exc: BaseException) -> None:
|
||||
"""Set an exception and wake up waiters."""
|
||||
if self._event.is_set():
|
||||
raise RuntimeError("Resolver already completed")
|
||||
self._exception = exc
|
||||
self._event.set()
|
||||
|
||||
async def wait(self) -> T:
|
||||
"""Wait for the result and return it, or raise the exception."""
|
||||
await self._event.wait()
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
# If we reach here, set_result() was called, so _value is set
|
||||
return cast(T, self._value)
|
||||
|
||||
def done(self) -> bool:
|
||||
"""Return True if the resolver has been completed."""
|
||||
return self._event.is_set()
|
||||
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
TaskStore - Abstract interface for task state storage.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from mcp.types import Result, Task, TaskMetadata, TaskStatus
|
||||
|
||||
|
||||
class TaskStore(ABC):
|
||||
"""
|
||||
Abstract interface for task state storage.
|
||||
|
||||
This is a pure storage interface - it doesn't manage execution.
|
||||
Implementations can use in-memory storage, databases, Redis, etc.
|
||||
|
||||
All methods are async to support various backends.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def create_task(
|
||||
self,
|
||||
metadata: TaskMetadata,
|
||||
task_id: str | None = None,
|
||||
) -> Task:
|
||||
"""
|
||||
Create a new task.
|
||||
|
||||
Args:
|
||||
metadata: Task metadata (ttl, etc.)
|
||||
task_id: Optional task ID. If None, implementation should generate one.
|
||||
|
||||
Returns:
|
||||
The created Task with status="working"
|
||||
|
||||
Raises:
|
||||
ValueError: If task_id already exists
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_task(self, task_id: str) -> Task | None:
|
||||
"""
|
||||
Get a task by ID.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
The Task, or None if not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def update_task(
|
||||
self,
|
||||
task_id: str,
|
||||
status: TaskStatus | None = None,
|
||||
status_message: str | None = None,
|
||||
) -> Task:
|
||||
"""
|
||||
Update a task's status and/or message.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
status: New status (if changing)
|
||||
status_message: New status message (if changing)
|
||||
|
||||
Returns:
|
||||
The updated Task
|
||||
|
||||
Raises:
|
||||
ValueError: If task not found
|
||||
ValueError: If attempting to transition from a terminal status
|
||||
(completed, failed, cancelled). Per spec, terminal states
|
||||
MUST NOT transition to any other status.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def store_result(self, task_id: str, result: Result) -> None:
|
||||
"""
|
||||
Store the result for a task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
result: The result to store
|
||||
|
||||
Raises:
|
||||
ValueError: If task not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_result(self, task_id: str) -> Result | None:
|
||||
"""
|
||||
Get the stored result for a task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
The stored Result, or None if not available
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def list_tasks(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
) -> tuple[list[Task], str | None]:
|
||||
"""
|
||||
List tasks with pagination.
|
||||
|
||||
Args:
|
||||
cursor: Optional cursor for pagination
|
||||
|
||||
Returns:
|
||||
Tuple of (tasks, next_cursor). next_cursor is None if no more pages.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def delete_task(self, task_id: str) -> bool:
|
||||
"""
|
||||
Delete a task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def wait_for_update(self, task_id: str) -> None:
|
||||
"""
|
||||
Wait until the task status changes.
|
||||
|
||||
This blocks until either:
|
||||
1. The task status changes
|
||||
2. The wait is cancelled
|
||||
|
||||
Used by tasks/result to wait for task completion or status changes.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Raises:
|
||||
ValueError: If task not found
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def notify_update(self, task_id: str) -> None:
|
||||
"""
|
||||
Signal that a task has been updated.
|
||||
|
||||
This wakes up any coroutines waiting in wait_for_update().
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
"""
|
||||
Reference in New Issue
Block a user