Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for an active project before fetching all sessions
- When an active project exists, use project.sessions instead of fetching from the API
- Added detailed console logging to debug session filtering
- This prevents ALL sessions from appearing in every project's sidebar

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Server-side experimental features.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Import directly from submodules:
|
||||
- mcp.server.experimental.task_context.ServerTaskContext
|
||||
- mcp.server.experimental.task_support.TaskSupport
|
||||
- mcp.server.experimental.task_result_handler.TaskResultHandler
|
||||
- mcp.server.experimental.request_context.Experimental
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,238 @@
|
||||
"""
|
||||
Experimental request context features.
|
||||
|
||||
This module provides the Experimental class which gives access to experimental
|
||||
features within a request context, such as task-augmented request handling.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from mcp.server.experimental.task_context import ServerTaskContext
|
||||
from mcp.server.experimental.task_support import TaskSupport
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY, is_terminal
|
||||
from mcp.types import (
|
||||
METHOD_NOT_FOUND,
|
||||
TASK_FORBIDDEN,
|
||||
TASK_REQUIRED,
|
||||
ClientCapabilities,
|
||||
CreateTaskResult,
|
||||
ErrorData,
|
||||
Result,
|
||||
TaskExecutionMode,
|
||||
TaskMetadata,
|
||||
Tool,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class Experimental:
    """
    Experimental features context for task-augmented requests.

    Provides helpers for validating task execution compatibility and
    running tasks with automatic lifecycle management.

    WARNING: This API is experimental and may change without notice.
    """

    # Task metadata from the incoming request; None means the request was
    # not task-augmented.
    task_metadata: TaskMetadata | None = None
    _client_capabilities: ClientCapabilities | None = field(default=None, repr=False)
    _session: ServerSession | None = field(default=None, repr=False)
    _task_support: TaskSupport | None = field(default=None, repr=False)

    @property
    def is_task(self) -> bool:
        """Check if this request is task-augmented."""
        return self.task_metadata is not None

    @property
    def client_supports_tasks(self) -> bool:
        """Check if the client declared task support."""
        if self._client_capabilities is None:
            return False
        return self._client_capabilities.tasks is not None

    def validate_task_mode(
        self,
        tool_task_mode: TaskExecutionMode | None,
        *,
        raise_error: bool = True,
    ) -> ErrorData | None:
        """
        Validate that the request is compatible with the tool's task execution mode.

        Per MCP spec:
        - "required": Clients MUST invoke as task. Server returns -32601 if not.
        - "forbidden" (or None): Clients MUST NOT invoke as task. Server returns -32601 if they do.
        - "optional": Either is acceptable.

        Args:
            tool_task_mode: The tool's execution.taskSupport value
                ("forbidden", "optional", "required", or None)
            raise_error: If True, raises McpError on validation failure. If False, returns ErrorData.

        Returns:
            None if valid, ErrorData if invalid and raise_error=False

        Raises:
            McpError: If invalid and raise_error=True
        """
        # Absent/None taskSupport is treated as "forbidden" per the spec.
        mode = tool_task_mode or TASK_FORBIDDEN

        error: ErrorData | None = None

        if mode == TASK_REQUIRED and not self.is_task:
            error = ErrorData(
                code=METHOD_NOT_FOUND,
                message="This tool requires task-augmented invocation",
            )
        elif mode == TASK_FORBIDDEN and self.is_task:
            error = ErrorData(
                code=METHOD_NOT_FOUND,
                message="This tool does not support task-augmented invocation",
            )

        if error is not None and raise_error:
            raise McpError(error)

        return error

    def validate_for_tool(
        self,
        tool: Tool,
        *,
        raise_error: bool = True,
    ) -> ErrorData | None:
        """
        Validate that the request is compatible with the given tool.

        Convenience wrapper around validate_task_mode that extracts the mode from a Tool.

        Args:
            tool: The Tool definition
            raise_error: If True, raises McpError on validation failure.

        Returns:
            None if valid, ErrorData if invalid and raise_error=False
        """
        mode = tool.execution.taskSupport if tool.execution else None
        return self.validate_task_mode(mode, raise_error=raise_error)

    def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool:
        """
        Check if this client can use a tool with the given task mode.

        Useful for filtering tool lists or providing warnings.
        Returns False if tool requires "required" but client doesn't support tasks.

        Args:
            tool_task_mode: The tool's execution.taskSupport value

        Returns:
            True if the client can use this tool, False otherwise
        """
        mode = tool_task_mode or TASK_FORBIDDEN
        if mode == TASK_REQUIRED and not self.client_supports_tasks:
            return False
        return True

    async def run_task(
        self,
        work: Callable[[ServerTaskContext], Awaitable[Result]],
        *,
        task_id: str | None = None,
        model_immediate_response: str | None = None,
    ) -> CreateTaskResult:
        """
        Create a task, spawn background work, and return CreateTaskResult immediately.

        This is the recommended way to handle task-augmented tool calls. It:
        1. Creates a task in the store
        2. Spawns the work function in a background task
        3. Returns CreateTaskResult immediately

        The work function receives a ServerTaskContext with:
        - elicit() for sending elicitation requests
        - create_message() for sampling requests
        - update_status() for progress updates
        - complete()/fail() for finishing the task

        When work() returns a Result, the task is auto-completed with that result.
        If work() raises an exception, the task is auto-failed.

        Args:
            work: Async function that does the actual work
            task_id: Optional task ID (generated if not provided)
            model_immediate_response: Optional string to include in _meta as
                io.modelcontextprotocol/model-immediate-response

        Returns:
            CreateTaskResult to return to the client

        Raises:
            RuntimeError: If task support is not enabled or task_metadata is missing

        Example:
            @server.call_tool()
            async def handle_tool(name: str, args: dict):
                ctx = server.request_context

                async def work(task: ServerTaskContext) -> CallToolResult:
                    result = await task.elicit(
                        message="Are you sure?",
                        requestedSchema={"type": "object", ...}
                    )
                    confirmed = result.content.get("confirm", False)
                    return CallToolResult(content=[TextContent(text="Done" if confirmed else "Cancelled")])

                return await ctx.experimental.run_task(work)

        WARNING: This API is experimental and may change without notice.
        """
        if self._task_support is None:
            raise RuntimeError("Task support not enabled. Call server.experimental.enable_tasks() first.")
        if self._session is None:
            raise RuntimeError("Session not available.")
        if self.task_metadata is None:
            raise RuntimeError(
                "Request is not task-augmented (no task field in params). "
                "The client must send a task-augmented request."
            )

        support = self._task_support
        # Access task_group via TaskSupport - raises if not in run() context
        task_group = support.task_group

        task = await support.store.create_task(self.task_metadata, task_id)

        task_ctx = ServerTaskContext(
            task=task,
            store=support.store,
            session=self._session,
            queue=support.queue,
            handler=support.handler,
        )

        async def execute() -> None:
            # Auto-complete/auto-fail only if work() didn't already move the
            # task to a terminal state itself.
            try:
                result = await work(task_ctx)
                if not is_terminal(task_ctx.task.status):
                    await task_ctx.complete(result)
            except Exception as e:
                if not is_terminal(task_ctx.task.status):
                    await task_ctx.fail(str(e))

        task_group.start_soon(execute)

        meta: dict[str, Any] | None = None
        if model_immediate_response is not None:
            meta = {MODEL_IMMEDIATE_RESPONSE_KEY: model_immediate_response}

        # "_meta" is only passed when set so the field stays absent otherwise.
        return CreateTaskResult(task=task, **{"_meta": meta} if meta else {})
|
||||
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Experimental server session features for server→client task operations.
|
||||
|
||||
This module provides the server-side equivalent of ExperimentalClientFeatures,
|
||||
allowing the server to send task-augmented requests to the client and poll for results.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
|
||||
from mcp.shared.experimental.tasks.capabilities import (
|
||||
require_task_augmented_elicitation,
|
||||
require_task_augmented_sampling,
|
||||
)
|
||||
from mcp.shared.experimental.tasks.polling import poll_until_terminal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.server.session import ServerSession
|
||||
|
||||
ResultT = TypeVar("ResultT", bound=types.Result)
|
||||
|
||||
|
||||
class ExperimentalServerSessionFeatures:
    """
    Experimental server session features for server→client task operations.

    This provides the server-side equivalent of ExperimentalClientFeatures,
    allowing the server to send task-augmented requests to the client and
    poll for results.

    WARNING: These APIs are experimental and may change without notice.

    Access via session.experimental:
        result = await session.experimental.elicit_as_task(...)
    """

    def __init__(self, session: "ServerSession") -> None:
        self._session = session

    async def get_task(self, task_id: str) -> types.GetTaskResult:
        """
        Send tasks/get to the client to get task status.

        Args:
            task_id: The task identifier

        Returns:
            GetTaskResult containing the task status
        """
        return await self._session.send_request(
            types.ServerRequest(types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))),
            types.GetTaskResult,
        )

    async def get_task_result(
        self,
        task_id: str,
        result_type: type[ResultT],
    ) -> ResultT:
        """
        Send tasks/result to the client to retrieve the final result.

        Args:
            task_id: The task identifier
            result_type: The expected result type

        Returns:
            The task result, validated against result_type
        """
        return await self._session.send_request(
            types.ServerRequest(types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id))),
            result_type,
        )

    async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
        """
        Poll a client task until it reaches terminal status.

        Yields GetTaskResult for each poll, allowing the caller to react to
        status changes. Exits when task reaches a terminal status.

        Respects the pollInterval hint from the client.

        Args:
            task_id: The task identifier

        Yields:
            GetTaskResult for each poll
        """
        async for status in poll_until_terminal(self.get_task, task_id):
            yield status

    async def elicit_as_task(
        self,
        message: str,
        requestedSchema: types.ElicitRequestedSchema,
        *,
        ttl: int = 60000,
    ) -> types.ElicitResult:
        """
        Send a task-augmented elicitation to the client and poll until complete.

        The client will create a local task, process the elicitation asynchronously,
        and return the result when ready. This method handles the full flow:
        1. Send elicitation with task field
        2. Receive CreateTaskResult from client
        3. Poll client's task until terminal
        4. Retrieve and return the final ElicitResult

        Args:
            message: The message to present to the user
            requestedSchema: Schema defining the expected response
            ttl: Task time-to-live in milliseconds

        Returns:
            The client's elicitation response

        Raises:
            McpError: If client doesn't support task-augmented elicitation
        """
        client_caps = self._session.client_params.capabilities if self._session.client_params else None
        require_task_augmented_elicitation(client_caps)

        create_result = await self._session.send_request(
            types.ServerRequest(
                types.ElicitRequest(
                    params=types.ElicitRequestFormParams(
                        message=message,
                        requestedSchema=requestedSchema,
                        task=types.TaskMetadata(ttl=ttl),
                    )
                )
            ),
            types.CreateTaskResult,
        )

        task_id = create_result.task.taskId

        # Drain status updates; we only need to reach a terminal state before
        # fetching the final result.
        async for _ in self.poll_task(task_id):
            pass

        return await self.get_task_result(task_id, types.ElicitResult)

    async def create_message_as_task(
        self,
        messages: list[types.SamplingMessage],
        *,
        max_tokens: int,
        ttl: int = 60000,
        system_prompt: str | None = None,
        include_context: types.IncludeContext | None = None,
        temperature: float | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        model_preferences: types.ModelPreferences | None = None,
        tools: list[types.Tool] | None = None,
        tool_choice: types.ToolChoice | None = None,
    ) -> types.CreateMessageResult:
        """
        Send a task-augmented sampling request and poll until complete.

        The client will create a local task, process the sampling request
        asynchronously, and return the result when ready.

        Args:
            messages: The conversation messages for sampling
            max_tokens: Maximum tokens in the response
            ttl: Task time-to-live in milliseconds
            system_prompt: Optional system prompt
            include_context: Context inclusion strategy
            temperature: Sampling temperature
            stop_sequences: Stop sequences
            metadata: Additional metadata
            model_preferences: Model selection preferences
            tools: Optional list of tools the LLM can use during sampling
            tool_choice: Optional control over tool usage behavior

        Returns:
            The sampling result from the client

        Raises:
            McpError: If client doesn't support task-augmented sampling or tools
            ValueError: If tool_use or tool_result message structure is invalid
        """
        client_caps = self._session.client_params.capabilities if self._session.client_params else None
        require_task_augmented_sampling(client_caps)
        validate_sampling_tools(client_caps, tools, tool_choice)
        validate_tool_use_result_messages(messages)

        create_result = await self._session.send_request(
            types.ServerRequest(
                types.CreateMessageRequest(
                    params=types.CreateMessageRequestParams(
                        messages=messages,
                        maxTokens=max_tokens,
                        systemPrompt=system_prompt,
                        includeContext=include_context,
                        temperature=temperature,
                        stopSequences=stop_sequences,
                        metadata=metadata,
                        modelPreferences=model_preferences,
                        tools=tools,
                        toolChoice=tool_choice,
                        task=types.TaskMetadata(ttl=ttl),
                    )
                )
            ),
            types.CreateTaskResult,
        )

        task_id = create_result.task.taskId

        # Drain status updates; we only need to reach a terminal state before
        # fetching the final result.
        async for _ in self.poll_task(task_id):
            pass

        return await self.get_task_result(task_id, types.CreateMessageResult)
|
||||
@@ -0,0 +1,612 @@
|
||||
"""
|
||||
ServerTaskContext - Server-integrated task context with elicitation and sampling.
|
||||
|
||||
This wraps the pure TaskContext and adds server-specific functionality:
|
||||
- Elicitation (task.elicit())
|
||||
- Sampling (task.create_message())
|
||||
- Status notifications
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.server.experimental.task_result_handler import TaskResultHandler
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.experimental.tasks.capabilities import (
|
||||
require_task_augmented_elicitation,
|
||||
require_task_augmented_sampling,
|
||||
)
|
||||
from mcp.shared.experimental.tasks.context import TaskContext
|
||||
from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue
|
||||
from mcp.shared.experimental.tasks.resolver import Resolver
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
from mcp.types import (
|
||||
INVALID_REQUEST,
|
||||
TASK_STATUS_INPUT_REQUIRED,
|
||||
TASK_STATUS_WORKING,
|
||||
ClientCapabilities,
|
||||
CreateMessageResult,
|
||||
CreateTaskResult,
|
||||
ElicitationCapability,
|
||||
ElicitRequestedSchema,
|
||||
ElicitResult,
|
||||
ErrorData,
|
||||
IncludeContext,
|
||||
ModelPreferences,
|
||||
RequestId,
|
||||
Result,
|
||||
SamplingCapability,
|
||||
SamplingMessage,
|
||||
ServerNotification,
|
||||
Task,
|
||||
TaskMetadata,
|
||||
TaskStatusNotification,
|
||||
TaskStatusNotificationParams,
|
||||
Tool,
|
||||
ToolChoice,
|
||||
)
|
||||
|
||||
|
||||
class ServerTaskContext:
|
||||
"""
|
||||
Server-integrated task context with elicitation and sampling.
|
||||
|
||||
This wraps a pure TaskContext and adds server-specific functionality:
|
||||
- elicit() for sending elicitation requests to the client
|
||||
- create_message() for sampling requests
|
||||
- Status notifications via the session
|
||||
|
||||
Example:
|
||||
async def my_task_work(task: ServerTaskContext) -> CallToolResult:
|
||||
await task.update_status("Starting...")
|
||||
|
||||
result = await task.elicit(
|
||||
message="Continue?",
|
||||
requestedSchema={"type": "object", "properties": {"ok": {"type": "boolean"}}}
|
||||
)
|
||||
|
||||
if result.content.get("ok"):
|
||||
return CallToolResult(content=[TextContent(text="Done!")])
|
||||
else:
|
||||
return CallToolResult(content=[TextContent(text="Cancelled")])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
task: Task,
|
||||
store: TaskStore,
|
||||
session: ServerSession,
|
||||
queue: TaskMessageQueue,
|
||||
handler: TaskResultHandler | None = None,
|
||||
):
|
||||
"""
|
||||
Create a ServerTaskContext.
|
||||
|
||||
Args:
|
||||
task: The Task object
|
||||
store: The task store
|
||||
session: The server session
|
||||
queue: The message queue for elicitation/sampling
|
||||
handler: The result handler for response routing (required for elicit/create_message)
|
||||
"""
|
||||
self._ctx = TaskContext(task=task, store=store)
|
||||
self._session = session
|
||||
self._queue = queue
|
||||
self._handler = handler
|
||||
self._store = store
|
||||
|
||||
# Delegate pure properties to inner context
|
||||
|
||||
@property
|
||||
def task_id(self) -> str:
|
||||
"""The task identifier."""
|
||||
return self._ctx.task_id
|
||||
|
||||
@property
|
||||
def task(self) -> Task:
|
||||
"""The current task state."""
|
||||
return self._ctx.task
|
||||
|
||||
@property
|
||||
def is_cancelled(self) -> bool:
|
||||
"""Whether cancellation has been requested."""
|
||||
return self._ctx.is_cancelled
|
||||
|
||||
def request_cancellation(self) -> None:
|
||||
"""Request cancellation of this task."""
|
||||
self._ctx.request_cancellation()
|
||||
|
||||
# Enhanced methods with notifications
|
||||
|
||||
async def update_status(self, message: str, *, notify: bool = True) -> None:
|
||||
"""
|
||||
Update the task's status message.
|
||||
|
||||
Args:
|
||||
message: The new status message
|
||||
notify: Whether to send a notification to the client
|
||||
"""
|
||||
await self._ctx.update_status(message)
|
||||
if notify:
|
||||
await self._send_notification()
|
||||
|
||||
async def complete(self, result: Result, *, notify: bool = True) -> None:
|
||||
"""
|
||||
Mark the task as completed with the given result.
|
||||
|
||||
Args:
|
||||
result: The task result
|
||||
notify: Whether to send a notification to the client
|
||||
"""
|
||||
await self._ctx.complete(result)
|
||||
if notify:
|
||||
await self._send_notification()
|
||||
|
||||
async def fail(self, error: str, *, notify: bool = True) -> None:
|
||||
"""
|
||||
Mark the task as failed with an error message.
|
||||
|
||||
Args:
|
||||
error: The error message
|
||||
notify: Whether to send a notification to the client
|
||||
"""
|
||||
await self._ctx.fail(error)
|
||||
if notify:
|
||||
await self._send_notification()
|
||||
|
||||
async def _send_notification(self) -> None:
|
||||
"""Send a task status notification to the client."""
|
||||
task = self._ctx.task
|
||||
await self._session.send_notification(
|
||||
ServerNotification(
|
||||
TaskStatusNotification(
|
||||
params=TaskStatusNotificationParams(
|
||||
taskId=task.taskId,
|
||||
status=task.status,
|
||||
statusMessage=task.statusMessage,
|
||||
createdAt=task.createdAt,
|
||||
lastUpdatedAt=task.lastUpdatedAt,
|
||||
ttl=task.ttl,
|
||||
pollInterval=task.pollInterval,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Server-specific methods: elicitation and sampling
|
||||
|
||||
def _check_elicitation_capability(self) -> None:
|
||||
"""Check if the client supports elicitation."""
|
||||
if not self._session.check_client_capability(ClientCapabilities(elicitation=ElicitationCapability())):
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_REQUEST,
|
||||
message="Client does not support elicitation capability",
|
||||
)
|
||||
)
|
||||
|
||||
def _check_sampling_capability(self) -> None:
|
||||
"""Check if the client supports sampling."""
|
||||
if not self._session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())):
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_REQUEST,
|
||||
message="Client does not support sampling capability",
|
||||
)
|
||||
)
|
||||
|
||||
    async def elicit(
        self,
        message: str,
        requestedSchema: ElicitRequestedSchema,
    ) -> ElicitResult:
        """
        Send an elicitation request via the task message queue.

        This method:
        1. Checks client capability
        2. Updates task status to "input_required"
        3. Queues the elicitation request
        4. Waits for the response (delivered via tasks/result round-trip)
        5. Updates task status back to "working"
        6. Returns the result

        Args:
            message: The message to present to the user
            requestedSchema: Schema defining the expected response structure

        Returns:
            The client's response

        Raises:
            McpError: If client doesn't support elicitation capability
        """
        self._check_elicitation_capability()

        if self._handler is None:
            raise RuntimeError("handler is required for elicit(). Pass handler= to ServerTaskContext.")

        # Update status to input_required
        await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)

        # Build the request using session's helper
        request = self._session._build_elicit_form_request(  # pyright: ignore[reportPrivateUsage]
            message=message,
            requestedSchema=requestedSchema,
            related_task_id=self.task_id,
        )
        request_id: RequestId = request.id

        # Register a resolver so the TaskResultHandler can route the client's
        # response for this request id back to us.
        resolver: Resolver[dict[str, Any]] = Resolver()
        self._handler._pending_requests[request_id] = resolver  # pyright: ignore[reportPrivateUsage]

        queued = QueuedMessage(
            type="request",
            message=request,
            resolver=resolver,
            original_request_id=request_id,
        )
        await self._queue.enqueue(self.task_id, queued)

        try:
            # Wait for response (routed back via TaskResultHandler)
            response_data = await resolver.wait()
            await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
            return ElicitResult.model_validate(response_data)
        except anyio.get_cancelled_exc_class():  # pragma: no cover
            # Coverage can't track async exception handlers reliably.
            # This path is tested in test_elicit_restores_status_on_cancellation
            # which verifies status is restored to "working" after cancellation.
            await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
            raise
|
||||
|
||||
    async def elicit_url(
        self,
        message: str,
        url: str,
        elicitation_id: str,
    ) -> ElicitResult:
        """
        Send a URL mode elicitation request via the task message queue.

        This directs the user to an external URL for out-of-band interactions
        like OAuth flows, credential collection, or payment processing.

        This method:
        1. Checks client capability
        2. Updates task status to "input_required"
        3. Queues the elicitation request
        4. Waits for the response (delivered via tasks/result round-trip)
        5. Updates task status back to "working"
        6. Returns the result

        Args:
            message: Human-readable explanation of why the interaction is needed
            url: The URL the user should navigate to
            elicitation_id: Unique identifier for tracking this elicitation

        Returns:
            The client's response indicating acceptance, decline, or cancellation

        Raises:
            McpError: If client doesn't support elicitation capability
            RuntimeError: If handler is not configured
        """
        self._check_elicitation_capability()

        if self._handler is None:
            raise RuntimeError("handler is required for elicit_url(). Pass handler= to ServerTaskContext.")

        # Update status to input_required
        await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)

        # Build the request using session's helper
        request = self._session._build_elicit_url_request(  # pyright: ignore[reportPrivateUsage]
            message=message,
            url=url,
            elicitation_id=elicitation_id,
            related_task_id=self.task_id,
        )
        request_id: RequestId = request.id

        # Register a resolver so the TaskResultHandler can route the client's
        # response for this request id back to us.
        resolver: Resolver[dict[str, Any]] = Resolver()
        self._handler._pending_requests[request_id] = resolver  # pyright: ignore[reportPrivateUsage]

        queued = QueuedMessage(
            type="request",
            message=request,
            resolver=resolver,
            original_request_id=request_id,
        )
        await self._queue.enqueue(self.task_id, queued)

        try:
            # Wait for response (routed back via TaskResultHandler)
            response_data = await resolver.wait()
            await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
            return ElicitResult.model_validate(response_data)
        except anyio.get_cancelled_exc_class():  # pragma: no cover
            # Restore "working" so the task isn't left stuck in input_required.
            await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
            raise
|
||||
|
||||
    async def create_message(
        self,
        messages: list[SamplingMessage],
        *,
        max_tokens: int,
        system_prompt: str | None = None,
        include_context: IncludeContext | None = None,
        temperature: float | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        model_preferences: ModelPreferences | None = None,
        tools: list[Tool] | None = None,
        tool_choice: ToolChoice | None = None,
    ) -> CreateMessageResult:
        """
        Send a sampling request via the task message queue.

        This method:
        1. Checks client capability
        2. Updates task status to "input_required"
        3. Queues the sampling request
        4. Waits for the response (delivered via tasks/result round-trip)
        5. Updates task status back to "working"
        6. Returns the result

        Args:
            messages: The conversation messages for sampling
            max_tokens: Maximum tokens in the response
            system_prompt: Optional system prompt
            include_context: Context inclusion strategy
            temperature: Sampling temperature
            stop_sequences: Stop sequences
            metadata: Additional metadata
            model_preferences: Model selection preferences
            tools: Optional list of tools the LLM can use during sampling
            tool_choice: Optional control over tool usage behavior

        Returns:
            The sampling result from the client

        Raises:
            McpError: If client doesn't support sampling capability or tools
            ValueError: If tool_use or tool_result message structure is invalid
        """
        self._check_sampling_capability()
        client_caps = self._session.client_params.capabilities if self._session.client_params else None
        validate_sampling_tools(client_caps, tools, tool_choice)
        validate_tool_use_result_messages(messages)

        if self._handler is None:
            raise RuntimeError("handler is required for create_message(). Pass handler= to ServerTaskContext.")

        # Update status to input_required
        await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)

        # Build the request using session's helper
        request = self._session._build_create_message_request(  # pyright: ignore[reportPrivateUsage]
            messages=messages,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            include_context=include_context,
            temperature=temperature,
            stop_sequences=stop_sequences,
            metadata=metadata,
            model_preferences=model_preferences,
            tools=tools,
            tool_choice=tool_choice,
            related_task_id=self.task_id,
        )
        request_id: RequestId = request.id

        # Register a resolver so the TaskResultHandler can route the client's
        # response for this request id back to us.
        resolver: Resolver[dict[str, Any]] = Resolver()
        self._handler._pending_requests[request_id] = resolver  # pyright: ignore[reportPrivateUsage]

        queued = QueuedMessage(
            type="request",
            message=request,
            resolver=resolver,
            original_request_id=request_id,
        )
        await self._queue.enqueue(self.task_id, queued)

        try:
            # Wait for response (routed back via TaskResultHandler)
            response_data = await resolver.wait()
            await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
            return CreateMessageResult.model_validate(response_data)
        except anyio.get_cancelled_exc_class():  # pragma: no cover
            # Coverage can't track async exception handlers reliably.
            # This path is tested in test_create_message_restores_status_on_cancellation
            # which verifies status is restored to "working" after cancellation.
            await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
            raise
|
||||
|
||||
async def elicit_as_task(
    self,
    message: str,
    requestedSchema: ElicitRequestedSchema,
    *,
    ttl: int = 60000,
) -> ElicitResult:
    """
    Queue a task-augmented elicitation and poll the client's task for the answer.

    Intended for use inside a task-augmented tool call: the elicitation request
    is enqueued and handed to the client when it calls tasks/result. The client
    answers with a CreateTaskResult for its own task, which we then poll to
    completion before fetching the final ElicitResult.

    Args:
        message: The message to present to the user
        requestedSchema: Schema defining the expected response structure
        ttl: Task time-to-live in milliseconds for the client's task

    Returns:
        The client's elicitation response

    Raises:
        McpError: If client doesn't support task-augmented elicitation
        RuntimeError: If handler is not configured
    """
    params = self._session.client_params
    require_task_augmented_elicitation(params.capabilities if params else None)

    if self._handler is None:
        raise RuntimeError("handler is required for elicit_as_task()")

    # Signal that this task is now blocked on client input.
    await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)

    request = self._session._build_elicit_form_request(  # pyright: ignore[reportPrivateUsage]
        message=message,
        requestedSchema=requestedSchema,
        related_task_id=self.task_id,
        task=TaskMetadata(ttl=ttl),
    )
    outgoing_id: RequestId = request.id

    # Register a resolver so the handler can route the client's reply back here.
    pending: Resolver[dict[str, Any]] = Resolver()
    self._handler._pending_requests[outgoing_id] = pending  # pyright: ignore[reportPrivateUsage]

    await self._queue.enqueue(
        self.task_id,
        QueuedMessage(
            type="request",
            message=request,
            resolver=pending,
            original_request_id=outgoing_id,
        ),
    )

    try:
        # The first reply is the client's CreateTaskResult for its side-task.
        created = CreateTaskResult.model_validate(await pending.wait())
        remote_task_id = created.task.taskId

        # Drive the client's task to completion via session.experimental.
        async for _ in self._session.experimental.poll_task(remote_task_id):
            pass

        # Fetch the finished elicitation result from the client.
        elicit_result = await self._session.experimental.get_task_result(
            remote_task_id,
            ElicitResult,
        )

        await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
        return elicit_result

    except anyio.get_cancelled_exc_class():  # pragma: no cover
        # Restore "working" so a cancelled wait does not strand the task
        # in input_required.
        await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
        raise
|
||||
|
||||
async def create_message_as_task(
    self,
    messages: list[SamplingMessage],
    *,
    max_tokens: int,
    ttl: int = 60000,
    system_prompt: str | None = None,
    include_context: IncludeContext | None = None,
    temperature: float | None = None,
    stop_sequences: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
    model_preferences: ModelPreferences | None = None,
    tools: list[Tool] | None = None,
    tool_choice: ToolChoice | None = None,
) -> CreateMessageResult:
    """
    Queue a task-augmented sampling request and poll the client's task for it.

    Intended for use inside a task-augmented tool call: the sampling request is
    enqueued and handed to the client when it calls tasks/result. The client
    answers with a CreateTaskResult for its own task, which we then poll to
    completion before fetching the final CreateMessageResult.

    Args:
        messages: The conversation messages for sampling
        max_tokens: Maximum tokens in the response
        ttl: Task time-to-live in milliseconds for the client's task
        system_prompt: Optional system prompt
        include_context: Context inclusion strategy
        temperature: Sampling temperature
        stop_sequences: Stop sequences
        metadata: Additional metadata
        model_preferences: Model selection preferences
        tools: Optional list of tools the LLM can use during sampling
        tool_choice: Optional control over tool usage behavior

    Returns:
        The sampling result from the client

    Raises:
        McpError: If client doesn't support task-augmented sampling or tools
        ValueError: If tool_use or tool_result message structure is invalid
        RuntimeError: If handler is not configured
    """
    params = self._session.client_params
    client_caps = params.capabilities if params else None
    require_task_augmented_sampling(client_caps)
    validate_sampling_tools(client_caps, tools, tool_choice)
    validate_tool_use_result_messages(messages)

    if self._handler is None:
        raise RuntimeError("handler is required for create_message_as_task()")

    # Signal that this task is now blocked on client input.
    await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)

    # Including task= makes this a task-augmented sampling request.
    request = self._session._build_create_message_request(  # pyright: ignore[reportPrivateUsage]
        messages=messages,
        max_tokens=max_tokens,
        system_prompt=system_prompt,
        include_context=include_context,
        temperature=temperature,
        stop_sequences=stop_sequences,
        metadata=metadata,
        model_preferences=model_preferences,
        tools=tools,
        tool_choice=tool_choice,
        related_task_id=self.task_id,
        task=TaskMetadata(ttl=ttl),
    )
    outgoing_id: RequestId = request.id

    # Register a resolver so the handler can route the client's reply back here.
    pending: Resolver[dict[str, Any]] = Resolver()
    self._handler._pending_requests[outgoing_id] = pending  # pyright: ignore[reportPrivateUsage]

    await self._queue.enqueue(
        self.task_id,
        QueuedMessage(
            type="request",
            message=request,
            resolver=pending,
            original_request_id=outgoing_id,
        ),
    )

    try:
        # The first reply is the client's CreateTaskResult for its side-task.
        created = CreateTaskResult.model_validate(await pending.wait())
        remote_task_id = created.task.taskId

        # Drive the client's task to completion via session.experimental.
        async for _ in self._session.experimental.poll_task(remote_task_id):
            pass

        # Fetch the finished sampling result from the client.
        sampling_result = await self._session.experimental.get_task_result(
            remote_task_id,
            CreateMessageResult,
        )

        await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
        return sampling_result

    except anyio.get_cancelled_exc_class():  # pragma: no cover
        # Restore "working" so a cancelled wait does not strand the task
        # in input_required.
        await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
        raise
|
||||
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
TaskResultHandler - Integrated handler for tasks/result endpoint.
|
||||
|
||||
This implements the dequeue-send-wait pattern from the MCP Tasks spec:
|
||||
1. Dequeue all pending messages for the task
|
||||
2. Send them to the client via transport with relatedRequestId routing
|
||||
3. Wait if task is not in terminal state
|
||||
4. Return final result when task completes
|
||||
|
||||
This is the core of the task message queue pattern.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal
|
||||
from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue
|
||||
from mcp.shared.experimental.tasks.resolver import Resolver
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||
from mcp.types import (
|
||||
INVALID_PARAMS,
|
||||
ErrorData,
|
||||
GetTaskPayloadRequest,
|
||||
GetTaskPayloadResult,
|
||||
JSONRPCMessage,
|
||||
RelatedTaskMetadata,
|
||||
RequestId,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskResultHandler:
    """
    Implements tasks/result via the dequeue-send-wait message-queue pattern.

    On each tasks/result call this handler:
    1. Drains every queued message (elicitations, sampling requests,
       notifications) for the task and forwards them to the client
    2. Registers resolvers so client responses are routed back to callers
    3. Blocks until the task reaches a terminal state
    4. Returns the task's final payload

    Usage:
        # Create handler with store and queue
        handler = TaskResultHandler(task_store, message_queue)

        # Register it with the server
        @server.experimental.get_task_result()
        async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult:
            ctx = server.request_context
            return await handler.handle(req, ctx.session, ctx.request_id)

        # Or use the convenience method
        handler.register(server)
    """

    def __init__(
        self,
        store: TaskStore,
        queue: TaskMessageQueue,
    ):
        self._store = store
        self._queue = queue
        # Resolvers keyed by the internal request ID of each queued request;
        # used to route client responses back to the waiting caller.
        self._pending_requests: dict[RequestId, Resolver[dict[str, Any]]] = {}

    async def send_message(
        self,
        session: ServerSession,
        message: SessionMessage,
    ) -> None:
        """Deliver one queued task message through the session."""
        await session.send_message(message)

    async def handle(
        self,
        request: GetTaskPayloadRequest,
        session: ServerSession,
        request_id: RequestId,
    ) -> GetTaskPayloadResult:
        """
        Handle a tasks/result request with the dequeue-send-wait loop.

        Repeatedly: deliver all queued messages (with relatedRequestId set to
        this request's ID), then either return the final result if the task is
        terminal, or block until the task changes and loop again.

        Args:
            request: The GetTaskPayloadRequest
            session: The server session for sending messages
            request_id: The request ID for relatedRequestId routing

        Returns:
            GetTaskPayloadResult with the task's final payload

        Raises:
            McpError: INVALID_PARAMS if the task does not exist
        """
        task_id = request.params.taskId

        while True:
            task = await self._store.get_task(task_id)
            if task is None:
                raise McpError(
                    ErrorData(
                        code=INVALID_PARAMS,
                        message=f"Task not found: {task_id}",
                    )
                )

            await self._deliver_queued_messages(task_id, session, request_id)

            if not is_terminal(task.status):
                # Not finished yet: block until the store or queue changes,
                # then re-check from the top.
                await self._wait_for_task_update(task_id)
                continue

            # Terminal state reached. Per spec, tasks/result MUST include
            # related-task metadata in _meta.
            related = RelatedTaskMetadata(taskId=task_id)
            related_meta: dict[str, Any] = {RELATED_TASK_METADATA_KEY: related.model_dump(by_alias=True)}

            # GetTaskPayloadResult is a Result with extra="allow"; the stored
            # result (if any) carries the actual payload data.
            stored = await self._store.get_result(task_id)
            if stored is None:
                return GetTaskPayloadResult.model_validate({"_meta": related_meta})

            payload = stored.model_dump(by_alias=True)
            current_meta: dict[str, Any] = payload.get("_meta") or {}
            payload["_meta"] = {**current_meta, **related_meta}
            return GetTaskPayloadResult.model_validate(payload)

    async def _deliver_queued_messages(
        self,
        task_id: str,
        session: ServerSession,
        request_id: RequestId,
    ) -> None:
        """
        Drain the task's queue, forwarding each message to the client.

        Every message is sent with relatedRequestId set so that responses
        come back on this request's stream.
        """
        while (queued := await self._queue.dequeue(task_id)) is not None:
            # Requests (unlike notifications) expect a reply: remember the
            # resolver so route_response()/route_error() can find it later.
            if queued.type == "request" and queued.resolver is not None:
                if queued.original_request_id is not None:
                    self._pending_requests[queued.original_request_id] = queued.resolver

            logger.debug("Delivering queued message for task %s: %s", task_id, queued.type)

            await self.send_message(
                session,
                SessionMessage(
                    message=JSONRPCMessage(queued.message),
                    metadata=ServerMessageMetadata(related_request_id=request_id),
                ),
            )

    async def _wait_for_task_update(self, task_id: str) -> None:
        """
        Block until the task is updated (status change or new message).

        Races a store-update wait against a queue-message wait; whichever
        finishes first cancels the other.
        """
        async with anyio.create_task_group() as tg:

            async def race(waiter: Any) -> None:
                # Errors from either waiter are deliberately swallowed; the
                # caller re-reads task state on wakeup anyway.
                try:
                    await waiter(task_id)
                except Exception:
                    pass
                finally:
                    tg.cancel_scope.cancel()

            tg.start_soon(race, self._store.wait_for_update)
            tg.start_soon(race, self._queue.wait_for_message)

    def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool:
        """
        Deliver a response to the resolver waiting on this request ID.

        Called when a response arrives for a previously queued request.

        Args:
            request_id: The request ID from the response
            response: The response data

        Returns:
            True if response was routed, False if no pending request
        """
        waiting = self._pending_requests.pop(request_id, None)
        if waiting is None or waiting.done():
            return False
        waiting.set_result(response)
        return True

    def route_error(self, request_id: RequestId, error: ErrorData) -> bool:
        """
        Deliver an error to the resolver waiting on this request ID.

        Args:
            request_id: The request ID from the error response
            error: The error data

        Returns:
            True if error was routed, False if no pending request
        """
        waiting = self._pending_requests.pop(request_id, None)
        if waiting is None or waiting.done():
            return False
        waiting.set_exception(McpError(error))
        return True
|
||||
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
TaskSupport - Configuration for experimental task support.
|
||||
|
||||
This module provides the TaskSupport class which encapsulates all the
|
||||
infrastructure needed for task-augmented requests: store, queue, and handler.
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
|
||||
from mcp.server.experimental.task_result_handler import TaskResultHandler
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
|
||||
from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
|
||||
|
||||
@dataclass
class TaskSupport:
    """
    Bundles the infrastructure for experimental task support.

    Holds the task store, the message queue, the tasks/result handler built
    from them, and (while running) a task group for background work.

    When enabled on a server, this automatically:
    - Configures response routing for each session
    - Provides default handlers for task operations
    - Manages a task group for background task execution

    Example:
        # Simple in-memory setup
        server.experimental.enable_tasks()

        # Custom store/queue for distributed systems
        server.experimental.enable_tasks(
            store=RedisTaskStore(redis_url),
            queue=RedisTaskMessageQueue(redis_url),
        )
    """

    store: TaskStore
    queue: TaskMessageQueue
    handler: TaskResultHandler = field(init=False)
    _task_group: TaskGroup | None = field(init=False, default=None)

    def __post_init__(self) -> None:
        """Build the tasks/result handler from the configured store and queue."""
        self.handler = TaskResultHandler(self.store, self.queue)

    @property
    def task_group(self) -> TaskGroup:
        """The task group for spawning background work.

        Raises:
            RuntimeError: If accessed outside a run() context
        """
        tg = self._task_group
        if tg is None:
            raise RuntimeError("TaskSupport not running. Ensure Server.run() is active.")
        return tg

    @asynccontextmanager
    async def run(self) -> AsyncIterator[None]:
        """
        Hold a task group open for the duration of the server run.

        Server.run() enters this automatically:

            async with task_support.run():
                ...  # task_group is available here
        """
        async with anyio.create_task_group() as tg:
            self._task_group = tg
            try:
                yield
            finally:
                # Clear the reference so task_group raises outside run().
                self._task_group = None

    def configure_session(self, session: ServerSession) -> None:
        """
        Wire task response routing into a session.

        Registers the result handler as a response router so that replies to
        queued requests (elicitation, sampling) reach their waiting resolvers.
        Called automatically by Server.run() for each new session.

        Args:
            session: The session to configure
        """
        session.add_response_router(self.handler)

    @classmethod
    def in_memory(cls) -> "TaskSupport":
        """
        Build a TaskSupport backed by in-memory store and queue.

        Good for development, testing, and single-process servers; distributed
        deployments should supply custom store and queue implementations.

        Returns:
            TaskSupport configured with in-memory store and queue
        """
        return cls(store=InMemoryTaskStore(), queue=InMemoryTaskMessageQueue())
|
||||
Reference in New Issue
Block a user