Fix project isolation: Make loadChatHistory respect active project sessions

- Modified loadChatHistory() to check for active project before fetching all sessions
- When active project exists, use project.sessions instead of fetching from API
- Added detailed console logging to debug session filtering
- This prevents ALL sessions from appearing in every project's sidebar

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
uroma
2026-01-22 14:43:05 +00:00
Unverified
parent b82837aa5f
commit 55aafbae9a
6463 changed files with 1115462 additions and 4486 deletions

View File

@@ -0,0 +1,11 @@
"""
Server-side experimental features.
WARNING: These APIs are experimental and may change without notice.
Import directly from submodules:
- mcp.server.experimental.task_context.ServerTaskContext
- mcp.server.experimental.task_support.TaskSupport
- mcp.server.experimental.task_result_handler.TaskResultHandler
- mcp.server.experimental.request_context.Experimental
"""

View File

@@ -0,0 +1,238 @@
"""
Experimental request context features.
This module provides the Experimental class which gives access to experimental
features within a request context, such as task-augmented request handling.
WARNING: These APIs are experimental and may change without notice.
"""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
from mcp.server.experimental.task_context import ServerTaskContext
from mcp.server.experimental.task_support import TaskSupport
from mcp.server.session import ServerSession
from mcp.shared.exceptions import McpError
from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY, is_terminal
from mcp.types import (
METHOD_NOT_FOUND,
TASK_FORBIDDEN,
TASK_REQUIRED,
ClientCapabilities,
CreateTaskResult,
ErrorData,
Result,
TaskExecutionMode,
TaskMetadata,
Tool,
)
@dataclass
class Experimental:
"""
Experimental features context for task-augmented requests.
Provides helpers for validating task execution compatibility and
running tasks with automatic lifecycle management.
WARNING: This API is experimental and may change without notice.
"""
task_metadata: TaskMetadata | None = None
_client_capabilities: ClientCapabilities | None = field(default=None, repr=False)
_session: ServerSession | None = field(default=None, repr=False)
_task_support: TaskSupport | None = field(default=None, repr=False)
@property
def is_task(self) -> bool:
"""Check if this request is task-augmented."""
return self.task_metadata is not None
@property
def client_supports_tasks(self) -> bool:
"""Check if the client declared task support."""
if self._client_capabilities is None:
return False
return self._client_capabilities.tasks is not None
def validate_task_mode(
self,
tool_task_mode: TaskExecutionMode | None,
*,
raise_error: bool = True,
) -> ErrorData | None:
"""
Validate that the request is compatible with the tool's task execution mode.
Per MCP spec:
- "required": Clients MUST invoke as task. Server returns -32601 if not.
- "forbidden" (or None): Clients MUST NOT invoke as task. Server returns -32601 if they do.
- "optional": Either is acceptable.
Args:
tool_task_mode: The tool's execution.taskSupport value
("forbidden", "optional", "required", or None)
raise_error: If True, raises McpError on validation failure. If False, returns ErrorData.
Returns:
None if valid, ErrorData if invalid and raise_error=False
Raises:
McpError: If invalid and raise_error=True
"""
mode = tool_task_mode or TASK_FORBIDDEN
error: ErrorData | None = None
if mode == TASK_REQUIRED and not self.is_task:
error = ErrorData(
code=METHOD_NOT_FOUND,
message="This tool requires task-augmented invocation",
)
elif mode == TASK_FORBIDDEN and self.is_task:
error = ErrorData(
code=METHOD_NOT_FOUND,
message="This tool does not support task-augmented invocation",
)
if error is not None and raise_error:
raise McpError(error)
return error
def validate_for_tool(
self,
tool: Tool,
*,
raise_error: bool = True,
) -> ErrorData | None:
"""
Validate that the request is compatible with the given tool.
Convenience wrapper around validate_task_mode that extracts the mode from a Tool.
Args:
tool: The Tool definition
raise_error: If True, raises McpError on validation failure.
Returns:
None if valid, ErrorData if invalid and raise_error=False
"""
mode = tool.execution.taskSupport if tool.execution else None
return self.validate_task_mode(mode, raise_error=raise_error)
def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool:
"""
Check if this client can use a tool with the given task mode.
Useful for filtering tool lists or providing warnings.
Returns False if tool requires "required" but client doesn't support tasks.
Args:
tool_task_mode: The tool's execution.taskSupport value
Returns:
True if the client can use this tool, False otherwise
"""
mode = tool_task_mode or TASK_FORBIDDEN
if mode == TASK_REQUIRED and not self.client_supports_tasks:
return False
return True
async def run_task(
self,
work: Callable[[ServerTaskContext], Awaitable[Result]],
*,
task_id: str | None = None,
model_immediate_response: str | None = None,
) -> CreateTaskResult:
"""
Create a task, spawn background work, and return CreateTaskResult immediately.
This is the recommended way to handle task-augmented tool calls. It:
1. Creates a task in the store
2. Spawns the work function in a background task
3. Returns CreateTaskResult immediately
The work function receives a ServerTaskContext with:
- elicit() for sending elicitation requests
- create_message() for sampling requests
- update_status() for progress updates
- complete()/fail() for finishing the task
When work() returns a Result, the task is auto-completed with that result.
If work() raises an exception, the task is auto-failed.
Args:
work: Async function that does the actual work
task_id: Optional task ID (generated if not provided)
model_immediate_response: Optional string to include in _meta as
io.modelcontextprotocol/model-immediate-response
Returns:
CreateTaskResult to return to the client
Raises:
RuntimeError: If task support is not enabled or task_metadata is missing
Example:
@server.call_tool()
async def handle_tool(name: str, args: dict):
ctx = server.request_context
async def work(task: ServerTaskContext) -> CallToolResult:
result = await task.elicit(
message="Are you sure?",
requestedSchema={"type": "object", ...}
)
confirmed = result.content.get("confirm", False)
return CallToolResult(content=[TextContent(text="Done" if confirmed else "Cancelled")])
return await ctx.experimental.run_task(work)
WARNING: This API is experimental and may change without notice.
"""
if self._task_support is None:
raise RuntimeError("Task support not enabled. Call server.experimental.enable_tasks() first.")
if self._session is None:
raise RuntimeError("Session not available.")
if self.task_metadata is None:
raise RuntimeError(
"Request is not task-augmented (no task field in params). "
"The client must send a task-augmented request."
)
support = self._task_support
# Access task_group via TaskSupport - raises if not in run() context
task_group = support.task_group
task = await support.store.create_task(self.task_metadata, task_id)
task_ctx = ServerTaskContext(
task=task,
store=support.store,
session=self._session,
queue=support.queue,
handler=support.handler,
)
async def execute() -> None:
try:
result = await work(task_ctx)
if not is_terminal(task_ctx.task.status):
await task_ctx.complete(result)
except Exception as e:
if not is_terminal(task_ctx.task.status):
await task_ctx.fail(str(e))
task_group.start_soon(execute)
meta: dict[str, Any] | None = None
if model_immediate_response is not None:
meta = {MODEL_IMMEDIATE_RESPONSE_KEY: model_immediate_response}
return CreateTaskResult(task=task, **{"_meta": meta} if meta else {})

View File

@@ -0,0 +1,220 @@
"""
Experimental server session features for server→client task operations.
This module provides the server-side equivalent of ExperimentalClientFeatures,
allowing the server to send task-augmented requests to the client and poll for results.
WARNING: These APIs are experimental and may change without notice.
"""
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, TypeVar
import mcp.types as types
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
from mcp.shared.experimental.tasks.capabilities import (
require_task_augmented_elicitation,
require_task_augmented_sampling,
)
from mcp.shared.experimental.tasks.polling import poll_until_terminal
if TYPE_CHECKING:
from mcp.server.session import ServerSession
ResultT = TypeVar("ResultT", bound=types.Result)
class ExperimentalServerSessionFeatures:
"""
Experimental server session features for server→client task operations.
This provides the server-side equivalent of ExperimentalClientFeatures,
allowing the server to send task-augmented requests to the client and
poll for results.
WARNING: These APIs are experimental and may change without notice.
Access via session.experimental:
result = await session.experimental.elicit_as_task(...)
"""
def __init__(self, session: "ServerSession") -> None:
self._session = session
async def get_task(self, task_id: str) -> types.GetTaskResult:
"""
Send tasks/get to the client to get task status.
Args:
task_id: The task identifier
Returns:
GetTaskResult containing the task status
"""
return await self._session.send_request(
types.ServerRequest(types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))),
types.GetTaskResult,
)
async def get_task_result(
self,
task_id: str,
result_type: type[ResultT],
) -> ResultT:
"""
Send tasks/result to the client to retrieve the final result.
Args:
task_id: The task identifier
result_type: The expected result type
Returns:
The task result, validated against result_type
"""
return await self._session.send_request(
types.ServerRequest(types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id))),
result_type,
)
async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
"""
Poll a client task until it reaches terminal status.
Yields GetTaskResult for each poll, allowing the caller to react to
status changes. Exits when task reaches a terminal status.
Respects the pollInterval hint from the client.
Args:
task_id: The task identifier
Yields:
GetTaskResult for each poll
"""
async for status in poll_until_terminal(self.get_task, task_id):
yield status
async def elicit_as_task(
self,
message: str,
requestedSchema: types.ElicitRequestedSchema,
*,
ttl: int = 60000,
) -> types.ElicitResult:
"""
Send a task-augmented elicitation to the client and poll until complete.
The client will create a local task, process the elicitation asynchronously,
and return the result when ready. This method handles the full flow:
1. Send elicitation with task field
2. Receive CreateTaskResult from client
3. Poll client's task until terminal
4. Retrieve and return the final ElicitResult
Args:
message: The message to present to the user
requestedSchema: Schema defining the expected response
ttl: Task time-to-live in milliseconds
Returns:
The client's elicitation response
Raises:
McpError: If client doesn't support task-augmented elicitation
"""
client_caps = self._session.client_params.capabilities if self._session.client_params else None
require_task_augmented_elicitation(client_caps)
create_result = await self._session.send_request(
types.ServerRequest(
types.ElicitRequest(
params=types.ElicitRequestFormParams(
message=message,
requestedSchema=requestedSchema,
task=types.TaskMetadata(ttl=ttl),
)
)
),
types.CreateTaskResult,
)
task_id = create_result.task.taskId
async for _ in self.poll_task(task_id):
pass
return await self.get_task_result(task_id, types.ElicitResult)
async def create_message_as_task(
self,
messages: list[types.SamplingMessage],
*,
max_tokens: int,
ttl: int = 60000,
system_prompt: str | None = None,
include_context: types.IncludeContext | None = None,
temperature: float | None = None,
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: types.ModelPreferences | None = None,
tools: list[types.Tool] | None = None,
tool_choice: types.ToolChoice | None = None,
) -> types.CreateMessageResult:
"""
Send a task-augmented sampling request and poll until complete.
The client will create a local task, process the sampling request
asynchronously, and return the result when ready.
Args:
messages: The conversation messages for sampling
max_tokens: Maximum tokens in the response
ttl: Task time-to-live in milliseconds
system_prompt: Optional system prompt
include_context: Context inclusion strategy
temperature: Sampling temperature
stop_sequences: Stop sequences
metadata: Additional metadata
model_preferences: Model selection preferences
tools: Optional list of tools the LLM can use during sampling
tool_choice: Optional control over tool usage behavior
Returns:
The sampling result from the client
Raises:
McpError: If client doesn't support task-augmented sampling or tools
ValueError: If tool_use or tool_result message structure is invalid
"""
client_caps = self._session.client_params.capabilities if self._session.client_params else None
require_task_augmented_sampling(client_caps)
validate_sampling_tools(client_caps, tools, tool_choice)
validate_tool_use_result_messages(messages)
create_result = await self._session.send_request(
types.ServerRequest(
types.CreateMessageRequest(
params=types.CreateMessageRequestParams(
messages=messages,
maxTokens=max_tokens,
systemPrompt=system_prompt,
includeContext=include_context,
temperature=temperature,
stopSequences=stop_sequences,
metadata=metadata,
modelPreferences=model_preferences,
tools=tools,
toolChoice=tool_choice,
task=types.TaskMetadata(ttl=ttl),
)
)
),
types.CreateTaskResult,
)
task_id = create_result.task.taskId
async for _ in self.poll_task(task_id):
pass
return await self.get_task_result(task_id, types.CreateMessageResult)

View File

@@ -0,0 +1,612 @@
"""
ServerTaskContext - Server-integrated task context with elicitation and sampling.
This wraps the pure TaskContext and adds server-specific functionality:
- Elicitation (task.elicit())
- Sampling (task.create_message())
- Status notifications
"""
from typing import Any
import anyio
from mcp.server.experimental.task_result_handler import TaskResultHandler
from mcp.server.session import ServerSession
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
from mcp.shared.exceptions import McpError
from mcp.shared.experimental.tasks.capabilities import (
require_task_augmented_elicitation,
require_task_augmented_sampling,
)
from mcp.shared.experimental.tasks.context import TaskContext
from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue
from mcp.shared.experimental.tasks.resolver import Resolver
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.types import (
INVALID_REQUEST,
TASK_STATUS_INPUT_REQUIRED,
TASK_STATUS_WORKING,
ClientCapabilities,
CreateMessageResult,
CreateTaskResult,
ElicitationCapability,
ElicitRequestedSchema,
ElicitResult,
ErrorData,
IncludeContext,
ModelPreferences,
RequestId,
Result,
SamplingCapability,
SamplingMessage,
ServerNotification,
Task,
TaskMetadata,
TaskStatusNotification,
TaskStatusNotificationParams,
Tool,
ToolChoice,
)
class ServerTaskContext:
"""
Server-integrated task context with elicitation and sampling.
This wraps a pure TaskContext and adds server-specific functionality:
- elicit() for sending elicitation requests to the client
- create_message() for sampling requests
- Status notifications via the session
Example:
async def my_task_work(task: ServerTaskContext) -> CallToolResult:
await task.update_status("Starting...")
result = await task.elicit(
message="Continue?",
requestedSchema={"type": "object", "properties": {"ok": {"type": "boolean"}}}
)
if result.content.get("ok"):
return CallToolResult(content=[TextContent(text="Done!")])
else:
return CallToolResult(content=[TextContent(text="Cancelled")])
"""
def __init__(
self,
*,
task: Task,
store: TaskStore,
session: ServerSession,
queue: TaskMessageQueue,
handler: TaskResultHandler | None = None,
):
"""
Create a ServerTaskContext.
Args:
task: The Task object
store: The task store
session: The server session
queue: The message queue for elicitation/sampling
handler: The result handler for response routing (required for elicit/create_message)
"""
self._ctx = TaskContext(task=task, store=store)
self._session = session
self._queue = queue
self._handler = handler
self._store = store
# Delegate pure properties to inner context
@property
def task_id(self) -> str:
"""The task identifier."""
return self._ctx.task_id
@property
def task(self) -> Task:
"""The current task state."""
return self._ctx.task
@property
def is_cancelled(self) -> bool:
"""Whether cancellation has been requested."""
return self._ctx.is_cancelled
def request_cancellation(self) -> None:
"""Request cancellation of this task."""
self._ctx.request_cancellation()
# Enhanced methods with notifications
async def update_status(self, message: str, *, notify: bool = True) -> None:
"""
Update the task's status message.
Args:
message: The new status message
notify: Whether to send a notification to the client
"""
await self._ctx.update_status(message)
if notify:
await self._send_notification()
async def complete(self, result: Result, *, notify: bool = True) -> None:
"""
Mark the task as completed with the given result.
Args:
result: The task result
notify: Whether to send a notification to the client
"""
await self._ctx.complete(result)
if notify:
await self._send_notification()
async def fail(self, error: str, *, notify: bool = True) -> None:
"""
Mark the task as failed with an error message.
Args:
error: The error message
notify: Whether to send a notification to the client
"""
await self._ctx.fail(error)
if notify:
await self._send_notification()
async def _send_notification(self) -> None:
"""Send a task status notification to the client."""
task = self._ctx.task
await self._session.send_notification(
ServerNotification(
TaskStatusNotification(
params=TaskStatusNotificationParams(
taskId=task.taskId,
status=task.status,
statusMessage=task.statusMessage,
createdAt=task.createdAt,
lastUpdatedAt=task.lastUpdatedAt,
ttl=task.ttl,
pollInterval=task.pollInterval,
)
)
)
)
# Server-specific methods: elicitation and sampling
def _check_elicitation_capability(self) -> None:
"""Check if the client supports elicitation."""
if not self._session.check_client_capability(ClientCapabilities(elicitation=ElicitationCapability())):
raise McpError(
ErrorData(
code=INVALID_REQUEST,
message="Client does not support elicitation capability",
)
)
def _check_sampling_capability(self) -> None:
"""Check if the client supports sampling."""
if not self._session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())):
raise McpError(
ErrorData(
code=INVALID_REQUEST,
message="Client does not support sampling capability",
)
)
async def elicit(
self,
message: str,
requestedSchema: ElicitRequestedSchema,
) -> ElicitResult:
"""
Send an elicitation request via the task message queue.
This method:
1. Checks client capability
2. Updates task status to "input_required"
3. Queues the elicitation request
4. Waits for the response (delivered via tasks/result round-trip)
5. Updates task status back to "working"
6. Returns the result
Args:
message: The message to present to the user
requestedSchema: Schema defining the expected response structure
Returns:
The client's response
Raises:
McpError: If client doesn't support elicitation capability
"""
self._check_elicitation_capability()
if self._handler is None:
raise RuntimeError("handler is required for elicit(). Pass handler= to ServerTaskContext.")
# Update status to input_required
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
# Build the request using session's helper
request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage]
message=message,
requestedSchema=requestedSchema,
related_task_id=self.task_id,
)
request_id: RequestId = request.id
resolver: Resolver[dict[str, Any]] = Resolver()
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
queued = QueuedMessage(
type="request",
message=request,
resolver=resolver,
original_request_id=request_id,
)
await self._queue.enqueue(self.task_id, queued)
try:
# Wait for response (routed back via TaskResultHandler)
response_data = await resolver.wait()
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
return ElicitResult.model_validate(response_data)
except anyio.get_cancelled_exc_class(): # pragma: no cover
# Coverage can't track async exception handlers reliably.
# This path is tested in test_elicit_restores_status_on_cancellation
# which verifies status is restored to "working" after cancellation.
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
raise
async def elicit_url(
self,
message: str,
url: str,
elicitation_id: str,
) -> ElicitResult:
"""
Send a URL mode elicitation request via the task message queue.
This directs the user to an external URL for out-of-band interactions
like OAuth flows, credential collection, or payment processing.
This method:
1. Checks client capability
2. Updates task status to "input_required"
3. Queues the elicitation request
4. Waits for the response (delivered via tasks/result round-trip)
5. Updates task status back to "working"
6. Returns the result
Args:
message: Human-readable explanation of why the interaction is needed
url: The URL the user should navigate to
elicitation_id: Unique identifier for tracking this elicitation
Returns:
The client's response indicating acceptance, decline, or cancellation
Raises:
McpError: If client doesn't support elicitation capability
RuntimeError: If handler is not configured
"""
self._check_elicitation_capability()
if self._handler is None:
raise RuntimeError("handler is required for elicit_url(). Pass handler= to ServerTaskContext.")
# Update status to input_required
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
# Build the request using session's helper
request = self._session._build_elicit_url_request( # pyright: ignore[reportPrivateUsage]
message=message,
url=url,
elicitation_id=elicitation_id,
related_task_id=self.task_id,
)
request_id: RequestId = request.id
resolver: Resolver[dict[str, Any]] = Resolver()
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
queued = QueuedMessage(
type="request",
message=request,
resolver=resolver,
original_request_id=request_id,
)
await self._queue.enqueue(self.task_id, queued)
try:
# Wait for response (routed back via TaskResultHandler)
response_data = await resolver.wait()
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
return ElicitResult.model_validate(response_data)
except anyio.get_cancelled_exc_class(): # pragma: no cover
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
raise
async def create_message(
self,
messages: list[SamplingMessage],
*,
max_tokens: int,
system_prompt: str | None = None,
include_context: IncludeContext | None = None,
temperature: float | None = None,
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: ModelPreferences | None = None,
tools: list[Tool] | None = None,
tool_choice: ToolChoice | None = None,
) -> CreateMessageResult:
"""
Send a sampling request via the task message queue.
This method:
1. Checks client capability
2. Updates task status to "input_required"
3. Queues the sampling request
4. Waits for the response (delivered via tasks/result round-trip)
5. Updates task status back to "working"
6. Returns the result
Args:
messages: The conversation messages for sampling
max_tokens: Maximum tokens in the response
system_prompt: Optional system prompt
include_context: Context inclusion strategy
temperature: Sampling temperature
stop_sequences: Stop sequences
metadata: Additional metadata
model_preferences: Model selection preferences
tools: Optional list of tools the LLM can use during sampling
tool_choice: Optional control over tool usage behavior
Returns:
The sampling result from the client
Raises:
McpError: If client doesn't support sampling capability or tools
ValueError: If tool_use or tool_result message structure is invalid
"""
self._check_sampling_capability()
client_caps = self._session.client_params.capabilities if self._session.client_params else None
validate_sampling_tools(client_caps, tools, tool_choice)
validate_tool_use_result_messages(messages)
if self._handler is None:
raise RuntimeError("handler is required for create_message(). Pass handler= to ServerTaskContext.")
# Update status to input_required
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
# Build the request using session's helper
request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage]
messages=messages,
max_tokens=max_tokens,
system_prompt=system_prompt,
include_context=include_context,
temperature=temperature,
stop_sequences=stop_sequences,
metadata=metadata,
model_preferences=model_preferences,
tools=tools,
tool_choice=tool_choice,
related_task_id=self.task_id,
)
request_id: RequestId = request.id
resolver: Resolver[dict[str, Any]] = Resolver()
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
queued = QueuedMessage(
type="request",
message=request,
resolver=resolver,
original_request_id=request_id,
)
await self._queue.enqueue(self.task_id, queued)
try:
# Wait for response (routed back via TaskResultHandler)
response_data = await resolver.wait()
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
return CreateMessageResult.model_validate(response_data)
except anyio.get_cancelled_exc_class(): # pragma: no cover
# Coverage can't track async exception handlers reliably.
# This path is tested in test_create_message_restores_status_on_cancellation
# which verifies status is restored to "working" after cancellation.
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
raise
async def elicit_as_task(
self,
message: str,
requestedSchema: ElicitRequestedSchema,
*,
ttl: int = 60000,
) -> ElicitResult:
"""
Send a task-augmented elicitation via the queue, then poll client.
This is for use inside a task-augmented tool call when you want the client
to handle the elicitation as its own task. The elicitation request is queued
and delivered when the client calls tasks/result. After the client responds
with CreateTaskResult, we poll the client's task until complete.
Args:
message: The message to present to the user
requestedSchema: Schema defining the expected response structure
ttl: Task time-to-live in milliseconds for the client's task
Returns:
The client's elicitation response
Raises:
McpError: If client doesn't support task-augmented elicitation
RuntimeError: If handler is not configured
"""
client_caps = self._session.client_params.capabilities if self._session.client_params else None
require_task_augmented_elicitation(client_caps)
if self._handler is None:
raise RuntimeError("handler is required for elicit_as_task()")
# Update status to input_required
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage]
message=message,
requestedSchema=requestedSchema,
related_task_id=self.task_id,
task=TaskMetadata(ttl=ttl),
)
request_id: RequestId = request.id
resolver: Resolver[dict[str, Any]] = Resolver()
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
queued = QueuedMessage(
type="request",
message=request,
resolver=resolver,
original_request_id=request_id,
)
await self._queue.enqueue(self.task_id, queued)
try:
# Wait for initial response (CreateTaskResult from client)
response_data = await resolver.wait()
create_result = CreateTaskResult.model_validate(response_data)
client_task_id = create_result.task.taskId
# Poll the client's task using session.experimental
async for _ in self._session.experimental.poll_task(client_task_id):
pass
# Get final result from client
result = await self._session.experimental.get_task_result(
client_task_id,
ElicitResult,
)
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
return result
except anyio.get_cancelled_exc_class(): # pragma: no cover
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
raise
async def create_message_as_task(
self,
messages: list[SamplingMessage],
*,
max_tokens: int,
ttl: int = 60000,
system_prompt: str | None = None,
include_context: IncludeContext | None = None,
temperature: float | None = None,
stop_sequences: list[str] | None = None,
metadata: dict[str, Any] | None = None,
model_preferences: ModelPreferences | None = None,
tools: list[Tool] | None = None,
tool_choice: ToolChoice | None = None,
) -> CreateMessageResult:
"""
Send a task-augmented sampling request via the queue, then poll client.
This is for use inside a task-augmented tool call when you want the client
to handle the sampling as its own task. The request is queued and delivered
when the client calls tasks/result. After the client responds with
CreateTaskResult, we poll the client's task until complete.
Args:
messages: The conversation messages for sampling
max_tokens: Maximum tokens in the response
ttl: Task time-to-live in milliseconds for the client's task
system_prompt: Optional system prompt
include_context: Context inclusion strategy
temperature: Sampling temperature
stop_sequences: Stop sequences
metadata: Additional metadata
model_preferences: Model selection preferences
tools: Optional list of tools the LLM can use during sampling
tool_choice: Optional control over tool usage behavior
Returns:
The sampling result from the client
Raises:
McpError: If client doesn't support task-augmented sampling or tools
ValueError: If tool_use or tool_result message structure is invalid
RuntimeError: If handler is not configured
"""
client_caps = self._session.client_params.capabilities if self._session.client_params else None
require_task_augmented_sampling(client_caps)
validate_sampling_tools(client_caps, tools, tool_choice)
validate_tool_use_result_messages(messages)
if self._handler is None:
raise RuntimeError("handler is required for create_message_as_task()")
# Update status to input_required
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
# Build request WITH task field for task-augmented sampling
request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage]
messages=messages,
max_tokens=max_tokens,
system_prompt=system_prompt,
include_context=include_context,
temperature=temperature,
stop_sequences=stop_sequences,
metadata=metadata,
model_preferences=model_preferences,
tools=tools,
tool_choice=tool_choice,
related_task_id=self.task_id,
task=TaskMetadata(ttl=ttl),
)
request_id: RequestId = request.id
resolver: Resolver[dict[str, Any]] = Resolver()
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
queued = QueuedMessage(
type="request",
message=request,
resolver=resolver,
original_request_id=request_id,
)
await self._queue.enqueue(self.task_id, queued)
try:
# Wait for initial response (CreateTaskResult from client)
response_data = await resolver.wait()
create_result = CreateTaskResult.model_validate(response_data)
client_task_id = create_result.task.taskId
# Poll the client's task using session.experimental
async for _ in self._session.experimental.poll_task(client_task_id):
pass
# Get final result from client
result = await self._session.experimental.get_task_result(
client_task_id,
CreateMessageResult,
)
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
return result
except anyio.get_cancelled_exc_class(): # pragma: no cover
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
raise

View File

@@ -0,0 +1,235 @@
"""
TaskResultHandler - Integrated handler for tasks/result endpoint.
This implements the dequeue-send-wait pattern from the MCP Tasks spec:
1. Dequeue all pending messages for the task
2. Send them to the client via transport with relatedRequestId routing
3. Wait if task is not in terminal state
4. Return final result when task completes
This is the core of the task message queue pattern.
"""
import logging
from typing import Any
import anyio
from mcp.server.session import ServerSession
from mcp.shared.exceptions import McpError
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal
from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue
from mcp.shared.experimental.tasks.resolver import Resolver
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.types import (
INVALID_PARAMS,
ErrorData,
GetTaskPayloadRequest,
GetTaskPayloadResult,
JSONRPCMessage,
RelatedTaskMetadata,
RequestId,
)
logger = logging.getLogger(__name__)
class TaskResultHandler:
"""
Handler for tasks/result that implements the message queue pattern.
This handler:
1. Dequeues pending messages (elicitations, notifications) for the task
2. Sends them to the client via the response stream
3. Waits for responses and resolves them back to callers
4. Blocks until task reaches terminal state
5. Returns the final result
Usage:
# Create handler with store and queue
handler = TaskResultHandler(task_store, message_queue)
# Register it with the server
@server.experimental.get_task_result()
async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult:
ctx = server.request_context
return await handler.handle(req, ctx.session, ctx.request_id)
# Or use the convenience method
handler.register(server)
"""
def __init__(
self,
store: TaskStore,
queue: TaskMessageQueue,
):
self._store = store
self._queue = queue
# Map from internal request ID to resolver for routing responses
self._pending_requests: dict[RequestId, Resolver[dict[str, Any]]] = {}
async def send_message(
self,
session: ServerSession,
message: SessionMessage,
) -> None:
"""
Send a message via the session.
This is a helper for delivering queued task messages.
"""
await session.send_message(message)
async def handle(
self,
request: GetTaskPayloadRequest,
session: ServerSession,
request_id: RequestId,
) -> GetTaskPayloadResult:
"""
Handle a tasks/result request.
This implements the dequeue-send-wait loop:
1. Dequeue all pending messages
2. Send each via transport with relatedRequestId = this request's ID
3. If task not terminal, wait for status change
4. Loop until task is terminal
5. Return final result
Args:
request: The GetTaskPayloadRequest
session: The server session for sending messages
request_id: The request ID for relatedRequestId routing
Returns:
GetTaskPayloadResult with the task's final payload
"""
task_id = request.params.taskId
while True:
task = await self._store.get_task(task_id)
if task is None:
raise McpError(
ErrorData(
code=INVALID_PARAMS,
message=f"Task not found: {task_id}",
)
)
await self._deliver_queued_messages(task_id, session, request_id)
# If task is terminal, return result
if is_terminal(task.status):
result = await self._store.get_result(task_id)
# GetTaskPayloadResult is a Result with extra="allow"
# The stored result contains the actual payload data
# Per spec: tasks/result MUST include _meta with related-task metadata
related_task = RelatedTaskMetadata(taskId=task_id)
related_task_meta: dict[str, Any] = {RELATED_TASK_METADATA_KEY: related_task.model_dump(by_alias=True)}
if result is not None:
result_data = result.model_dump(by_alias=True)
existing_meta: dict[str, Any] = result_data.get("_meta") or {}
result_data["_meta"] = {**existing_meta, **related_task_meta}
return GetTaskPayloadResult.model_validate(result_data)
return GetTaskPayloadResult.model_validate({"_meta": related_task_meta})
# Wait for task update (status change or new messages)
await self._wait_for_task_update(task_id)
async def _deliver_queued_messages(
self,
task_id: str,
session: ServerSession,
request_id: RequestId,
) -> None:
"""
Dequeue and send all pending messages for a task.
Each message is sent via the session's write stream with
relatedRequestId set so responses route back to this stream.
"""
while True:
message = await self._queue.dequeue(task_id)
if message is None:
break
# If this is a request (not notification), wait for response
if message.type == "request" and message.resolver is not None:
# Store the resolver so we can route the response back
original_id = message.original_request_id
if original_id is not None:
self._pending_requests[original_id] = message.resolver
logger.debug("Delivering queued message for task %s: %s", task_id, message.type)
# Send the message with relatedRequestId for routing
session_message = SessionMessage(
message=JSONRPCMessage(message.message),
metadata=ServerMessageMetadata(related_request_id=request_id),
)
await self.send_message(session, session_message)
async def _wait_for_task_update(self, task_id: str) -> None:
"""
Wait for task to be updated (status change or new message).
Races between store update and queue message - first one wins.
"""
async with anyio.create_task_group() as tg:
async def wait_for_store() -> None:
try:
await self._store.wait_for_update(task_id)
except Exception:
pass
finally:
tg.cancel_scope.cancel()
async def wait_for_queue() -> None:
try:
await self._queue.wait_for_message(task_id)
except Exception:
pass
finally:
tg.cancel_scope.cancel()
tg.start_soon(wait_for_store)
tg.start_soon(wait_for_queue)
def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool:
"""
Route a response back to the waiting resolver.
This is called when a response arrives for a queued request.
Args:
request_id: The request ID from the response
response: The response data
Returns:
True if response was routed, False if no pending request
"""
resolver = self._pending_requests.pop(request_id, None)
if resolver is not None and not resolver.done():
resolver.set_result(response)
return True
return False
def route_error(self, request_id: RequestId, error: ErrorData) -> bool:
"""
Route an error back to the waiting resolver.
Args:
request_id: The request ID from the error response
error: The error data
Returns:
True if error was routed, False if no pending request
"""
resolver = self._pending_requests.pop(request_id, None)
if resolver is not None and not resolver.done():
resolver.set_exception(McpError(error))
return True
return False

View File

@@ -0,0 +1,115 @@
"""
TaskSupport - Configuration for experimental task support.
This module provides the TaskSupport class which encapsulates all the
infrastructure needed for task-augmented requests: store, queue, and handler.
"""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
import anyio
from anyio.abc import TaskGroup
from mcp.server.experimental.task_result_handler import TaskResultHandler
from mcp.server.session import ServerSession
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue
from mcp.shared.experimental.tasks.store import TaskStore
@dataclass
class TaskSupport:
"""
Configuration for experimental task support.
Encapsulates the task store, message queue, result handler, and task group
for spawning background work.
When enabled on a server, this automatically:
- Configures response routing for each session
- Provides default handlers for task operations
- Manages a task group for background task execution
Example:
# Simple in-memory setup
server.experimental.enable_tasks()
# Custom store/queue for distributed systems
server.experimental.enable_tasks(
store=RedisTaskStore(redis_url),
queue=RedisTaskMessageQueue(redis_url),
)
"""
store: TaskStore
queue: TaskMessageQueue
handler: TaskResultHandler = field(init=False)
_task_group: TaskGroup | None = field(init=False, default=None)
def __post_init__(self) -> None:
"""Create the result handler from store and queue."""
self.handler = TaskResultHandler(self.store, self.queue)
@property
def task_group(self) -> TaskGroup:
"""Get the task group for spawning background work.
Raises:
RuntimeError: If not within a run() context
"""
if self._task_group is None:
raise RuntimeError("TaskSupport not running. Ensure Server.run() is active.")
return self._task_group
@asynccontextmanager
async def run(self) -> AsyncIterator[None]:
"""
Run the task support lifecycle.
This creates a task group for spawning background task work.
Called automatically by Server.run().
Usage:
async with task_support.run():
# Task group is now available
...
"""
async with anyio.create_task_group() as tg:
self._task_group = tg
try:
yield
finally:
self._task_group = None
def configure_session(self, session: ServerSession) -> None:
"""
Configure a session for task support.
This registers the result handler as a response router so that
responses to queued requests (elicitation, sampling) are routed
back to the waiting resolvers.
Called automatically by Server.run() for each new session.
Args:
session: The session to configure
"""
session.add_response_router(self.handler)
@classmethod
def in_memory(cls) -> "TaskSupport":
"""
Create in-memory task support.
Suitable for development, testing, and single-process servers.
For distributed systems, provide custom store and queue implementations.
Returns:
TaskSupport configured with in-memory store and queue
"""
return cls(
store=InMemoryTaskStore(),
queue=InMemoryTaskMessageQueue(),
)