Fix project isolation: Make loadChatHistory respect active project sessions

- Modified loadChatHistory() to check for active project before fetching all sessions
- When active project exists, use project.sessions instead of fetching from API
- Added detailed console logging to debug session filtering
- This prevents ALL sessions from appearing in every project's sidebar

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
uroma
2026-01-22 14:43:05 +00:00
Unverified
parent b82837aa5f
commit 55aafbae9a
6463 changed files with 1115462 additions and 4486 deletions

View File

@@ -0,0 +1,7 @@
"""
Pure experimental MCP features (no server dependencies).
WARNING: These APIs are experimental and may change without notice.
For server-integrated experimental features, use mcp.server.experimental.
"""

View File

@@ -0,0 +1,12 @@
"""
Pure task state management for MCP.
WARNING: These APIs are experimental and may change without notice.
Import directly from submodules:
- mcp.shared.experimental.tasks.store.TaskStore
- mcp.shared.experimental.tasks.context.TaskContext
- mcp.shared.experimental.tasks.in_memory_task_store.InMemoryTaskStore
- mcp.shared.experimental.tasks.message_queue.TaskMessageQueue
- mcp.shared.experimental.tasks.helpers.is_terminal
"""

View File

@@ -0,0 +1,115 @@
"""
Tasks capability checking utilities.
This module provides functions for checking and requiring task-related
capabilities. All tasks capability logic is centralized here to keep
the main session code clean.
WARNING: These APIs are experimental and may change without notice.
"""
from mcp.shared.exceptions import McpError
from mcp.types import (
INVALID_REQUEST,
ClientCapabilities,
ClientTasksCapability,
ErrorData,
)
def check_tasks_capability(
    required: ClientTasksCapability,
    client: ClientTasksCapability,
) -> bool:
    """
    Check whether the client's declared tasks capability satisfies `required`.

    Args:
        required: The capability being checked for
        client: The client's declared capabilities

    Returns:
        True if the client has every capability named in `required`,
        False otherwise.
    """
    required_requests = required.requests
    if required_requests is None:
        # Nothing specific is required, so any client qualifies.
        return True

    client_requests = client.requests
    if client_requests is None:
        return False

    # Task-augmented elicitation (elicitation.create)
    if required_requests.elicitation is not None:
        if client_requests.elicitation is None:
            return False
        needs_create = required_requests.elicitation.create is not None
        if needs_create and client_requests.elicitation.create is None:
            return False

    # Task-augmented sampling (sampling.createMessage)
    if required_requests.sampling is not None:
        if client_requests.sampling is None:
            return False
        needs_create_message = required_requests.sampling.createMessage is not None
        if needs_create_message and client_requests.sampling.createMessage is None:
            return False

    return True
def has_task_augmented_elicitation(caps: ClientCapabilities) -> bool:
    """Check if capabilities include task-augmented elicitation support."""
    tasks = caps.tasks
    if tasks is None or tasks.requests is None:
        return False
    elicitation = tasks.requests.elicitation
    return elicitation is not None and elicitation.create is not None
def has_task_augmented_sampling(caps: ClientCapabilities) -> bool:
    """Check if capabilities include task-augmented sampling support."""
    tasks = caps.tasks
    if tasks is None or tasks.requests is None:
        return False
    sampling = tasks.requests.sampling
    return sampling is not None and sampling.createMessage is not None
def require_task_augmented_elicitation(client_caps: ClientCapabilities | None) -> None:
    """
    Raise McpError if the client doesn't support task-augmented elicitation.

    Args:
        client_caps: The client's declared capabilities, or None if not initialized

    Raises:
        McpError: If client doesn't support task-augmented elicitation
    """
    if client_caps is not None and has_task_augmented_elicitation(client_caps):
        return
    raise McpError(
        ErrorData(
            code=INVALID_REQUEST,
            message="Client does not support task-augmented elicitation",
        )
    )
def require_task_augmented_sampling(client_caps: ClientCapabilities | None) -> None:
    """
    Raise McpError if the client doesn't support task-augmented sampling.

    Args:
        client_caps: The client's declared capabilities, or None if not initialized

    Raises:
        McpError: If client doesn't support task-augmented sampling
    """
    if client_caps is not None and has_task_augmented_sampling(client_caps):
        return
    raise McpError(
        ErrorData(
            code=INVALID_REQUEST,
            message="Client does not support task-augmented sampling",
        )
    )

View File

@@ -0,0 +1,101 @@
"""
TaskContext - Pure task state management.
This module provides TaskContext, which manages task state without any
server/session dependencies. It can be used standalone for distributed
workers or wrapped by ServerTaskContext for full server integration.
"""
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.types import TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, Result, Task
class TaskContext:
    """
    Pure task state management with no session dependencies.

    Responsibilities:
    - Tracking task state (status, result) through the backing store
    - Tracking cancellation requests
    - Persisting state changes via TaskStore

    For server-integrated features (elicit, create_message, notifications),
    use ServerTaskContext from mcp.server.experimental.

    Example (distributed worker):
        async def worker_job(task_id: str):
            store = RedisTaskStore(redis_url)
            task = await store.get_task(task_id)
            ctx = TaskContext(task=task, store=store)
            await ctx.update_status("Working...")
            result = await do_work()
            await ctx.complete(result)
    """

    def __init__(self, task: Task, store: TaskStore):
        self._task = task
        self._store = store
        self._cancelled = False

    @property
    def task_id(self) -> str:
        """The task identifier."""
        return self._task.taskId

    @property
    def task(self) -> Task:
        """The current task state."""
        return self._task

    @property
    def is_cancelled(self) -> bool:
        """Whether cancellation has been requested."""
        return self._cancelled

    def request_cancellation(self) -> None:
        """
        Request cancellation of this task.

        Sets is_cancelled to True; task work is expected to check that flag
        periodically and exit gracefully once it is set.
        """
        self._cancelled = True

    async def update_status(self, message: str) -> None:
        """
        Update the task's status message.

        Args:
            message: The new status message
        """
        refreshed = await self._store.update_task(
            self.task_id,
            status_message=message,
        )
        self._task = refreshed

    async def complete(self, result: Result) -> None:
        """
        Mark the task as completed with the given result.

        Args:
            result: The task result
        """
        task_id = self.task_id
        # Persist the result first so it is available once status flips.
        await self._store.store_result(task_id, result)
        self._task = await self._store.update_task(
            task_id,
            status=TASK_STATUS_COMPLETED,
        )

    async def fail(self, error: str) -> None:
        """
        Mark the task as failed with an error message.

        Args:
            error: The error message
        """
        self._task = await self._store.update_task(
            self.task_id,
            status=TASK_STATUS_FAILED,
            status_message=error,
        )

View File

@@ -0,0 +1,181 @@
"""
Helper functions for pure task management.
These helpers work with pure TaskContext and don't require server dependencies.
For server-integrated task helpers, use mcp.server.experimental.
"""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from uuid import uuid4
from mcp.shared.exceptions import McpError
from mcp.shared.experimental.tasks.context import TaskContext
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.types import (
INVALID_PARAMS,
TASK_STATUS_CANCELLED,
TASK_STATUS_COMPLETED,
TASK_STATUS_FAILED,
TASK_STATUS_WORKING,
CancelTaskResult,
ErrorData,
Task,
TaskMetadata,
TaskStatus,
)
# Metadata key for model-immediate-response (per MCP spec).
# Servers MAY include this in CreateTaskResult._meta to provide an immediate
# response string while the task executes in the background.
MODEL_IMMEDIATE_RESPONSE_KEY = "io.modelcontextprotocol/model-immediate-response"

# Metadata key for associating requests with a task (per MCP spec).
RELATED_TASK_METADATA_KEY = "io.modelcontextprotocol/related-task"
def is_terminal(status: TaskStatus) -> bool:
    """
    Check if a task status represents a terminal state.

    Terminal states (completed, failed, cancelled) are those where the task
    has finished and its status will never change again.

    Args:
        status: The task status to check

    Returns:
        True if the status is terminal (completed, failed, or cancelled)
    """
    terminal_states = {TASK_STATUS_COMPLETED, TASK_STATUS_FAILED, TASK_STATUS_CANCELLED}
    return status in terminal_states
async def cancel_task(
    store: TaskStore,
    task_id: str,
) -> CancelTaskResult:
    """
    Cancel a task with spec-compliant validation.

    Per spec: "Receivers MUST reject cancellation of terminal status tasks
    with -32602 (Invalid params)"

    Validates that the task exists and is not already terminal before
    moving it to "cancelled".

    Args:
        store: The task store
        task_id: The task identifier to cancel

    Returns:
        CancelTaskResult with the cancelled task state

    Raises:
        McpError: With INVALID_PARAMS (-32602) if:
            - Task does not exist
            - Task is already in a terminal state (completed, failed, cancelled)

    Example:
        @server.experimental.cancel_task()
        async def handle_cancel(request: CancelTaskRequest) -> CancelTaskResult:
            return await cancel_task(store, request.params.taskId)
    """
    task = await store.get_task(task_id)
    if task is None:
        error = ErrorData(
            code=INVALID_PARAMS,
            message=f"Task not found: {task_id}",
        )
        raise McpError(error)
    if is_terminal(task.status):
        error = ErrorData(
            code=INVALID_PARAMS,
            message=f"Cannot cancel task in terminal state '{task.status}'",
        )
        raise McpError(error)
    # Validation passed: move the task to its final "cancelled" state.
    updated = await store.update_task(task_id, status=TASK_STATUS_CANCELLED)
    return CancelTaskResult(**updated.model_dump())
def generate_task_id() -> str:
    """Generate a unique task ID (a random UUID4 in canonical string form)."""
    return f"{uuid4()}"
def create_task_state(
    metadata: TaskMetadata,
    task_id: str | None = None,
) -> Task:
    """
    Create a Task object with initial state.

    This is a helper for TaskStore implementations.

    Args:
        metadata: Task metadata
        task_id: Optional task ID (generated if not provided)

    Returns:
        A new Task in "working" status
    """
    created = datetime.now(timezone.utc)
    if not task_id:
        task_id = generate_task_id()
    return Task(
        taskId=task_id,
        status=TASK_STATUS_WORKING,
        createdAt=created,
        lastUpdatedAt=created,
        ttl=metadata.ttl,
        pollInterval=500,  # Default 500ms poll interval
    )
@asynccontextmanager
async def task_execution(
    task_id: str,
    store: TaskStore,
) -> AsyncIterator[TaskContext]:
    """
    Context manager for safe task execution (pure, no server dependencies).

    Loads a task from the store and yields a TaskContext for the work.
    If an unhandled exception escapes the with-block, the task is marked as
    failed and the exception is suppressed (the failure is captured in task
    state instead). Useful for distributed workers without a server session.

    Args:
        task_id: The task identifier to execute
        store: The task store (must be accessible by the worker)

    Yields:
        TaskContext for updating status and completing/failing the task

    Raises:
        ValueError: If the task is not found in the store

    Example (distributed worker):
        async def worker_process(task_id: str):
            store = RedisTaskStore(redis_url)
            async with task_execution(task_id, store) as ctx:
                await ctx.update_status("Working...")
                result = await do_work()
                await ctx.complete(result)
    """
    loaded = await store.get_task(task_id)
    if loaded is None:
        raise ValueError(f"Task {task_id} not found")
    context = TaskContext(loaded, store)
    try:
        yield context
    except Exception as exc:
        # Record the failure in task state instead of propagating, unless the
        # task already reached a terminal status inside the with-block.
        if not is_terminal(context.task.status):
            await context.fail(str(exc))

View File

@@ -0,0 +1,219 @@
"""
In-memory implementation of TaskStore for demonstration purposes.
This implementation stores all tasks in memory and provides automatic cleanup
based on the TTL duration specified in the task metadata using lazy expiration.
Note: This is not suitable for production use as all data is lost on restart.
For production, consider implementing TaskStore with a database or distributed cache.
"""
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
import anyio
from mcp.shared.experimental.tasks.helpers import create_task_state, is_terminal
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.types import Result, Task, TaskMetadata, TaskStatus
@dataclass
class StoredTask:
    """Internal storage representation of a task (state, result, expiry)."""

    # Current task state; the store hands out copies on reads.
    task: Task
    # Result saved via store_result(); None until one is stored.
    result: Result | None = None
    # Time when this task should be removed (None = never)
    expires_at: datetime | None = field(default=None)
class InMemoryTaskStore(TaskStore):
    """
    A simple in-memory implementation of TaskStore.

    Features:
    - Automatic TTL-based cleanup (lazy expiration)
    - Safe for single-process async use
    - Pagination support for list_tasks
    - Multiple concurrent wait_for_update() waiters per task

    Limitations:
    - All data lost on restart
    - Not suitable for distributed systems
    - No persistence

    For production, implement TaskStore with Redis, PostgreSQL, etc.
    """

    def __init__(self, page_size: int = 10) -> None:
        self._tasks: dict[str, StoredTask] = {}
        self._page_size = page_size
        # One anyio.Event per concurrent waiter. A single event per task would
        # be clobbered when a second waiter registers (anyio.Event cannot be
        # reset), leaving the first waiter blocked forever; a list supports
        # any number of simultaneous waiters.
        self._update_events: dict[str, list[anyio.Event]] = {}

    def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
        """Calculate expiry time from TTL in milliseconds (None = never expires)."""
        if ttl_ms is None:
            return None
        return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms)

    def _is_expired(self, stored: StoredTask) -> bool:
        """Check if a task has expired."""
        if stored.expires_at is None:
            return False
        return datetime.now(timezone.utc) >= stored.expires_at

    def _cleanup_expired(self) -> None:
        """Remove all expired tasks. Called lazily during access operations."""
        expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)]
        for task_id in expired_ids:
            del self._tasks[task_id]

    async def create_task(
        self,
        metadata: TaskMetadata,
        task_id: str | None = None,
    ) -> Task:
        """
        Create a new task with the given metadata.

        Raises:
            ValueError: If a task with the same ID already exists.
        """
        # Cleanup expired tasks on access
        self._cleanup_expired()
        task = create_task_state(metadata, task_id)
        if task.taskId in self._tasks:
            raise ValueError(f"Task with ID {task.taskId} already exists")
        stored = StoredTask(
            task=task,
            expires_at=self._calculate_expiry(metadata.ttl),
        )
        self._tasks[task.taskId] = stored
        # Return a copy to prevent external modification
        return Task(**task.model_dump())

    async def get_task(self, task_id: str) -> Task | None:
        """Get a task by ID, or None if unknown/expired."""
        # Cleanup expired tasks on access
        self._cleanup_expired()
        stored = self._tasks.get(task_id)
        if stored is None:
            return None
        # Return a copy to prevent external modification
        return Task(**stored.task.model_dump())

    async def update_task(
        self,
        task_id: str,
        status: TaskStatus | None = None,
        status_message: str | None = None,
    ) -> Task:
        """
        Update a task's status and/or message.

        Raises:
            ValueError: If the task is unknown, or on an attempted transition
                out of a terminal status.
        """
        stored = self._tasks.get(task_id)
        if stored is None:
            raise ValueError(f"Task with ID {task_id} not found")
        # Per spec: Terminal states MUST NOT transition to any other status
        if status is not None and status != stored.task.status and is_terminal(stored.task.status):
            raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'")
        status_changed = False
        if status is not None and stored.task.status != status:
            stored.task.status = status
            status_changed = True
        if status_message is not None:
            stored.task.statusMessage = status_message
        # Update lastUpdatedAt on any change
        stored.task.lastUpdatedAt = datetime.now(timezone.utc)
        # If task is now terminal and has TTL, reset expiry timer
        if status is not None and is_terminal(status) and stored.task.ttl is not None:
            stored.expires_at = self._calculate_expiry(stored.task.ttl)
        # Notify waiters if status changed
        if status_changed:
            await self.notify_update(task_id)
        return Task(**stored.task.model_dump())

    async def store_result(self, task_id: str, result: Result) -> None:
        """Store the result for a task. Raises ValueError if task is unknown."""
        stored = self._tasks.get(task_id)
        if stored is None:
            raise ValueError(f"Task with ID {task_id} not found")
        stored.result = result

    async def get_result(self, task_id: str) -> Result | None:
        """Get the stored result for a task, or None if unavailable."""
        stored = self._tasks.get(task_id)
        if stored is None:
            return None
        return stored.result

    async def list_tasks(
        self,
        cursor: str | None = None,
    ) -> tuple[list[Task], str | None]:
        """
        List tasks with pagination.

        The cursor is the taskId of the last item of the previous page.

        Raises:
            ValueError: If the cursor does not reference a known task.
        """
        # Cleanup expired tasks on access
        self._cleanup_expired()
        all_task_ids = list(self._tasks.keys())
        start_index = 0
        if cursor is not None:
            try:
                cursor_index = all_task_ids.index(cursor)
                start_index = cursor_index + 1
            except ValueError:
                # Re-raise with a caller-friendly message; the original
                # "x is not in list" context adds nothing.
                raise ValueError(f"Invalid cursor: {cursor}") from None
        page_task_ids = all_task_ids[start_index : start_index + self._page_size]
        tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids]
        # Determine next cursor
        next_cursor = None
        if start_index + self._page_size < len(all_task_ids) and page_task_ids:
            next_cursor = page_task_ids[-1]
        return tasks, next_cursor

    async def delete_task(self, task_id: str) -> bool:
        """Delete a task. Returns True if deleted, False if not found."""
        if task_id not in self._tasks:
            return False
        del self._tasks[task_id]
        return True

    async def wait_for_update(self, task_id: str) -> None:
        """
        Wait until the task status changes.

        Each call registers its own event, so any number of coroutines may
        wait on the same task concurrently. (A single shared event would be
        replaced by the next waiter, stranding earlier waiters.)

        Raises:
            ValueError: If the task is unknown.
        """
        if task_id not in self._tasks:
            raise ValueError(f"Task with ID {task_id} not found")
        event = anyio.Event()
        self._update_events.setdefault(task_id, []).append(event)
        try:
            await event.wait()
        finally:
            # Deregister this waiter so stale events don't accumulate.
            waiters = self._update_events.get(task_id)
            if waiters is not None:
                if event in waiters:
                    waiters.remove(event)
                if not waiters:
                    del self._update_events[task_id]

    async def notify_update(self, task_id: str) -> None:
        """Signal that a task has been updated, waking every current waiter."""
        for event in self._update_events.pop(task_id, []):
            event.set()

    # --- Testing/debugging helpers ---

    def cleanup(self) -> None:
        """Cleanup all tasks (useful for testing or graceful shutdown)."""
        self._tasks.clear()
        self._update_events.clear()

    def get_all_tasks(self) -> list[Task]:
        """Get all tasks (useful for debugging). Returns copies to prevent modification."""
        self._cleanup_expired()
        return [Task(**stored.task.model_dump()) for stored in self._tasks.values()]

View File

@@ -0,0 +1,241 @@
"""
TaskMessageQueue - FIFO queue for task-related messages.
This implements the core message queue pattern from the MCP Tasks spec.
When a handler needs to send a request (like elicitation) during a task-augmented
request, the message is enqueued instead of sent directly. Messages are delivered
to the client only through the `tasks/result` endpoint.
This pattern enables:
1. Decoupling request handling from message delivery
2. Proper bidirectional communication via the tasks/result stream
3. Automatic status management (working <-> input_required)
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Literal
import anyio
from mcp.shared.experimental.tasks.resolver import Resolver
from mcp.types import JSONRPCNotification, JSONRPCRequest, RequestId
@dataclass
class QueuedMessage:
    """
    A message queued for delivery via tasks/result.

    Messages are stored with their type and, for requests that expect a
    response, a resolver that is completed when the response arrives.
    """

    type: Literal["request", "notification"]
    """Whether this is a request (expects response) or notification (one-way)."""

    message: JSONRPCRequest | JSONRPCNotification
    """The JSON-RPC message to send."""

    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    """When the message was enqueued."""

    resolver: Resolver[dict[str, Any]] | None = None
    """Resolver to set when response arrives (only for requests)."""

    original_request_id: RequestId | None = None
    """The original request ID used internally, for routing responses back."""
class TaskMessageQueue(ABC):
    """
    Abstract interface for task message queuing.

    This is a FIFO queue that stores messages to be delivered via `tasks/result`.
    When a task-augmented handler calls elicit() or sends a notification, the
    message is enqueued here instead of being sent directly to the client.

    The `tasks/result` handler then dequeues and sends these messages through
    the transport, with `relatedRequestId` set to the tasks/result request ID
    so responses are routed correctly.

    Implementations must preserve FIFO order per task and must make
    notify_message_available() wake coroutines blocked in wait_for_message().
    Implementations can use in-memory storage, Redis, etc.
    """

    @abstractmethod
    async def enqueue(self, task_id: str, message: QueuedMessage) -> None:
        """
        Add a message to the queue for a task.

        Args:
            task_id: The task identifier
            message: The message to enqueue
        """

    @abstractmethod
    async def dequeue(self, task_id: str) -> QueuedMessage | None:
        """
        Remove and return the next message from the queue.

        Args:
            task_id: The task identifier

        Returns:
            The next message, or None if queue is empty
        """

    @abstractmethod
    async def peek(self, task_id: str) -> QueuedMessage | None:
        """
        Return the next message without removing it.

        Args:
            task_id: The task identifier

        Returns:
            The next message, or None if queue is empty
        """

    @abstractmethod
    async def is_empty(self, task_id: str) -> bool:
        """
        Check if the queue is empty for a task.

        Args:
            task_id: The task identifier

        Returns:
            True if no messages are queued
        """

    @abstractmethod
    async def clear(self, task_id: str) -> list[QueuedMessage]:
        """
        Remove and return all messages from the queue.

        This is useful for cleanup when a task is cancelled or completed.

        Args:
            task_id: The task identifier

        Returns:
            All queued messages (may be empty)
        """

    @abstractmethod
    async def wait_for_message(self, task_id: str) -> None:
        """
        Wait until a message is available in the queue.

        This blocks until either:
        1. A message is enqueued for this task
        2. The wait is cancelled

        Args:
            task_id: The task identifier
        """

    @abstractmethod
    async def notify_message_available(self, task_id: str) -> None:
        """
        Signal that a message is available for a task.

        This wakes up any coroutines waiting in wait_for_message().

        Args:
            task_id: The task identifier
        """
class InMemoryTaskMessageQueue(TaskMessageQueue):
    """
    In-memory implementation of TaskMessageQueue.

    This is suitable for single-process servers. For distributed systems,
    implement TaskMessageQueue with Redis, RabbitMQ, etc.

    Features:
    - FIFO ordering per task
    - Async wait for message availability (multiple concurrent waiters supported)
    - Safe for single-process async use
    """

    def __init__(self) -> None:
        self._queues: dict[str, list[QueuedMessage]] = {}
        # One anyio.Event per concurrent waiter. A single event per task would
        # be clobbered when a second waiter registers (anyio.Event cannot be
        # reset), leaving the first waiter blocked forever.
        self._events: dict[str, list[anyio.Event]] = {}

    def _get_queue(self, task_id: str) -> list[QueuedMessage]:
        """Get or create the FIFO list for a task."""
        if task_id not in self._queues:
            self._queues[task_id] = []
        return self._queues[task_id]

    async def enqueue(self, task_id: str, message: QueuedMessage) -> None:
        """Add a message to the queue and wake any waiters."""
        queue = self._get_queue(task_id)
        queue.append(message)
        # Signal that a message is available
        await self.notify_message_available(task_id)

    async def dequeue(self, task_id: str) -> QueuedMessage | None:
        """Remove and return the next message, or None if the queue is empty."""
        queue = self._get_queue(task_id)
        if not queue:
            return None
        return queue.pop(0)

    async def peek(self, task_id: str) -> QueuedMessage | None:
        """Return the next message without removing it, or None if empty."""
        queue = self._get_queue(task_id)
        if not queue:
            return None
        return queue[0]

    async def is_empty(self, task_id: str) -> bool:
        """Check if the queue is empty."""
        queue = self._get_queue(task_id)
        return len(queue) == 0

    async def clear(self, task_id: str) -> list[QueuedMessage]:
        """Remove and return all messages."""
        queue = self._get_queue(task_id)
        messages = list(queue)
        queue.clear()
        return messages

    async def wait_for_message(self, task_id: str) -> None:
        """Wait until a message is available for this task."""
        # Fast path: something is already queued
        if not await self.is_empty(task_id):
            return
        # Register a per-waiter event, then re-check so a message enqueued
        # between the first check and registration is not missed.
        event = anyio.Event()
        self._events.setdefault(task_id, []).append(event)
        try:
            if not await self.is_empty(task_id):
                return
            await event.wait()
        finally:
            # Always deregister this waiter, including on the early-return
            # path, so stale events don't accumulate in the dict.
            waiters = self._events.get(task_id)
            if waiters is not None:
                if event in waiters:
                    waiters.remove(event)
                if not waiters:
                    del self._events[task_id]

    async def notify_message_available(self, task_id: str) -> None:
        """Signal availability, waking every coroutine currently waiting."""
        for event in self._events.pop(task_id, []):
            event.set()

    def cleanup(self, task_id: str | None = None) -> None:
        """
        Clean up queues and waiter events.

        Args:
            task_id: If provided, clean up only this task. Otherwise clean up all.
        """
        if task_id is not None:
            self._queues.pop(task_id, None)
            self._events.pop(task_id, None)
        else:
            self._queues.clear()
            self._events.clear()

View File

@@ -0,0 +1,45 @@
"""
Shared polling utilities for task operations.
This module provides generic polling logic that works for both client→server
and server→client task polling.
WARNING: These APIs are experimental and may change without notice.
"""
from collections.abc import AsyncIterator, Awaitable, Callable
import anyio
from mcp.shared.experimental.tasks.helpers import is_terminal
from mcp.types import GetTaskResult
async def poll_until_terminal(
    get_task: Callable[[str], Awaitable[GetTaskResult]],
    task_id: str,
    default_interval_ms: int = 500,
) -> AsyncIterator[GetTaskResult]:
    """
    Poll a task until it reaches terminal status.

    This is a generic utility that works for both client→server and
    server→client polling; the caller supplies the get_task function
    appropriate for their direction.

    Args:
        get_task: Async function that takes task_id and returns GetTaskResult
        task_id: The task to poll
        default_interval_ms: Fallback poll interval if server doesn't specify

    Yields:
        GetTaskResult for each poll
    """
    finished = False
    while not finished:
        snapshot = await get_task(task_id)
        yield snapshot
        finished = is_terminal(snapshot.status)
        if not finished:
            # The server-suggested interval wins; otherwise use the fallback.
            wait_ms = snapshot.pollInterval
            if wait_ms is None:
                wait_ms = default_interval_ms
            await anyio.sleep(wait_ms / 1000)

View File

@@ -0,0 +1,60 @@
"""
Resolver - An anyio-compatible future-like object for async result passing.
This provides a simple way to pass a result (or exception) from one coroutine
to another without depending on asyncio.Future.
"""
from typing import Generic, TypeVar, cast
import anyio
T = TypeVar("T")
class Resolver(Generic[T]):
    """
    A one-shot resolver for passing a result (or exception) between coroutines.

    Unlike asyncio.Future, this works with any anyio-compatible async backend.

    Usage:
        resolver: Resolver[str] = Resolver()

        # In one coroutine:
        resolver.set_result("hello")

        # In another coroutine:
        result = await resolver.wait()  # returns "hello"
    """

    def __init__(self) -> None:
        self._event = anyio.Event()
        self._value: T | None = None
        self._exception: BaseException | None = None

    def _ensure_pending(self) -> None:
        # A resolver is one-shot: completing it twice is a programming error.
        if self._event.is_set():
            raise RuntimeError("Resolver already completed")

    def set_result(self, value: T) -> None:
        """Set the result value and wake up waiters."""
        self._ensure_pending()
        self._value = value
        self._event.set()

    def set_exception(self, exc: BaseException) -> None:
        """Set an exception and wake up waiters."""
        self._ensure_pending()
        self._exception = exc
        self._event.set()

    async def wait(self) -> T:
        """Wait for completion; return the value or raise the stored exception."""
        await self._event.wait()
        if self._exception is not None:
            raise self._exception
        # Reaching here means set_result() ran, so _value holds the result.
        return cast(T, self._value)

    def done(self) -> bool:
        """Return True if the resolver has been completed."""
        return self._event.is_set()

View File

@@ -0,0 +1,156 @@
"""
TaskStore - Abstract interface for task state storage.
"""
from abc import ABC, abstractmethod
from mcp.types import Result, Task, TaskMetadata, TaskStatus
class TaskStore(ABC):
    """
    Abstract interface for task state storage.

    This is a pure storage interface - it doesn't manage execution.
    Implementations can use in-memory storage, databases, Redis, etc.
    All methods are async to support various backends.

    Besides CRUD, implementations provide the wait_for_update /
    notify_update pair used by tasks/result to block until a status change.
    """

    @abstractmethod
    async def create_task(
        self,
        metadata: TaskMetadata,
        task_id: str | None = None,
    ) -> Task:
        """
        Create a new task.

        Args:
            metadata: Task metadata (ttl, etc.)
            task_id: Optional task ID. If None, implementation should generate one.

        Returns:
            The created Task with status="working"

        Raises:
            ValueError: If task_id already exists
        """

    @abstractmethod
    async def get_task(self, task_id: str) -> Task | None:
        """
        Get a task by ID.

        Args:
            task_id: The task identifier

        Returns:
            The Task, or None if not found
        """

    @abstractmethod
    async def update_task(
        self,
        task_id: str,
        status: TaskStatus | None = None,
        status_message: str | None = None,
    ) -> Task:
        """
        Update a task's status and/or message.

        Args:
            task_id: The task identifier
            status: New status (if changing)
            status_message: New status message (if changing)

        Returns:
            The updated Task

        Raises:
            ValueError: If task not found
            ValueError: If attempting to transition from a terminal status
                (completed, failed, cancelled). Per spec, terminal states
                MUST NOT transition to any other status.
        """

    @abstractmethod
    async def store_result(self, task_id: str, result: Result) -> None:
        """
        Store the result for a task.

        Args:
            task_id: The task identifier
            result: The result to store

        Raises:
            ValueError: If task not found
        """

    @abstractmethod
    async def get_result(self, task_id: str) -> Result | None:
        """
        Get the stored result for a task.

        Args:
            task_id: The task identifier

        Returns:
            The stored Result, or None if not available
        """

    @abstractmethod
    async def list_tasks(
        self,
        cursor: str | None = None,
    ) -> tuple[list[Task], str | None]:
        """
        List tasks with pagination.

        Args:
            cursor: Optional cursor for pagination

        Returns:
            Tuple of (tasks, next_cursor). next_cursor is None if no more pages.
        """

    @abstractmethod
    async def delete_task(self, task_id: str) -> bool:
        """
        Delete a task.

        Args:
            task_id: The task identifier

        Returns:
            True if deleted, False if not found
        """

    @abstractmethod
    async def wait_for_update(self, task_id: str) -> None:
        """
        Wait until the task status changes.

        This blocks until either:
        1. The task status changes
        2. The wait is cancelled

        Used by tasks/result to wait for task completion or status changes.

        Args:
            task_id: The task identifier

        Raises:
            ValueError: If task not found
        """

    @abstractmethod
    async def notify_update(self, task_id: str) -> None:
        """
        Signal that a task has been updated.

        This wakes up any coroutines waiting in wait_for_update().

        Args:
            task_id: The task identifier
        """