Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for active project before fetching all sessions - When active project exists, use project.sessions instead of fetching from API - Added detailed console logging to debug session filtering - This prevents ALL sessions from appearing in every project's sidebar Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,612 @@
|
||||
"""
|
||||
ServerTaskContext - Server-integrated task context with elicitation and sampling.
|
||||
|
||||
This wraps the pure TaskContext and adds server-specific functionality:
|
||||
- Elicitation (task.elicit())
|
||||
- Sampling (task.create_message())
|
||||
- Status notifications
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.server.experimental.task_result_handler import TaskResultHandler
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.experimental.tasks.capabilities import (
|
||||
require_task_augmented_elicitation,
|
||||
require_task_augmented_sampling,
|
||||
)
|
||||
from mcp.shared.experimental.tasks.context import TaskContext
|
||||
from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue
|
||||
from mcp.shared.experimental.tasks.resolver import Resolver
|
||||
from mcp.shared.experimental.tasks.store import TaskStore
|
||||
from mcp.types import (
|
||||
INVALID_REQUEST,
|
||||
TASK_STATUS_INPUT_REQUIRED,
|
||||
TASK_STATUS_WORKING,
|
||||
ClientCapabilities,
|
||||
CreateMessageResult,
|
||||
CreateTaskResult,
|
||||
ElicitationCapability,
|
||||
ElicitRequestedSchema,
|
||||
ElicitResult,
|
||||
ErrorData,
|
||||
IncludeContext,
|
||||
ModelPreferences,
|
||||
RequestId,
|
||||
Result,
|
||||
SamplingCapability,
|
||||
SamplingMessage,
|
||||
ServerNotification,
|
||||
Task,
|
||||
TaskMetadata,
|
||||
TaskStatusNotification,
|
||||
TaskStatusNotificationParams,
|
||||
Tool,
|
||||
ToolChoice,
|
||||
)
|
||||
|
||||
|
||||
class ServerTaskContext:
|
||||
"""
|
||||
Server-integrated task context with elicitation and sampling.
|
||||
|
||||
This wraps a pure TaskContext and adds server-specific functionality:
|
||||
- elicit() for sending elicitation requests to the client
|
||||
- create_message() for sampling requests
|
||||
- Status notifications via the session
|
||||
|
||||
Example:
|
||||
async def my_task_work(task: ServerTaskContext) -> CallToolResult:
|
||||
await task.update_status("Starting...")
|
||||
|
||||
result = await task.elicit(
|
||||
message="Continue?",
|
||||
requestedSchema={"type": "object", "properties": {"ok": {"type": "boolean"}}}
|
||||
)
|
||||
|
||||
if result.content.get("ok"):
|
||||
return CallToolResult(content=[TextContent(text="Done!")])
|
||||
else:
|
||||
return CallToolResult(content=[TextContent(text="Cancelled")])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
task: Task,
|
||||
store: TaskStore,
|
||||
session: ServerSession,
|
||||
queue: TaskMessageQueue,
|
||||
handler: TaskResultHandler | None = None,
|
||||
):
|
||||
"""
|
||||
Create a ServerTaskContext.
|
||||
|
||||
Args:
|
||||
task: The Task object
|
||||
store: The task store
|
||||
session: The server session
|
||||
queue: The message queue for elicitation/sampling
|
||||
handler: The result handler for response routing (required for elicit/create_message)
|
||||
"""
|
||||
self._ctx = TaskContext(task=task, store=store)
|
||||
self._session = session
|
||||
self._queue = queue
|
||||
self._handler = handler
|
||||
self._store = store
|
||||
|
||||
# Delegate pure properties to inner context
|
||||
|
||||
@property
|
||||
def task_id(self) -> str:
|
||||
"""The task identifier."""
|
||||
return self._ctx.task_id
|
||||
|
||||
@property
|
||||
def task(self) -> Task:
|
||||
"""The current task state."""
|
||||
return self._ctx.task
|
||||
|
||||
@property
|
||||
def is_cancelled(self) -> bool:
|
||||
"""Whether cancellation has been requested."""
|
||||
return self._ctx.is_cancelled
|
||||
|
||||
def request_cancellation(self) -> None:
|
||||
"""Request cancellation of this task."""
|
||||
self._ctx.request_cancellation()
|
||||
|
||||
# Enhanced methods with notifications
|
||||
|
||||
async def update_status(self, message: str, *, notify: bool = True) -> None:
|
||||
"""
|
||||
Update the task's status message.
|
||||
|
||||
Args:
|
||||
message: The new status message
|
||||
notify: Whether to send a notification to the client
|
||||
"""
|
||||
await self._ctx.update_status(message)
|
||||
if notify:
|
||||
await self._send_notification()
|
||||
|
||||
async def complete(self, result: Result, *, notify: bool = True) -> None:
|
||||
"""
|
||||
Mark the task as completed with the given result.
|
||||
|
||||
Args:
|
||||
result: The task result
|
||||
notify: Whether to send a notification to the client
|
||||
"""
|
||||
await self._ctx.complete(result)
|
||||
if notify:
|
||||
await self._send_notification()
|
||||
|
||||
async def fail(self, error: str, *, notify: bool = True) -> None:
|
||||
"""
|
||||
Mark the task as failed with an error message.
|
||||
|
||||
Args:
|
||||
error: The error message
|
||||
notify: Whether to send a notification to the client
|
||||
"""
|
||||
await self._ctx.fail(error)
|
||||
if notify:
|
||||
await self._send_notification()
|
||||
|
||||
async def _send_notification(self) -> None:
|
||||
"""Send a task status notification to the client."""
|
||||
task = self._ctx.task
|
||||
await self._session.send_notification(
|
||||
ServerNotification(
|
||||
TaskStatusNotification(
|
||||
params=TaskStatusNotificationParams(
|
||||
taskId=task.taskId,
|
||||
status=task.status,
|
||||
statusMessage=task.statusMessage,
|
||||
createdAt=task.createdAt,
|
||||
lastUpdatedAt=task.lastUpdatedAt,
|
||||
ttl=task.ttl,
|
||||
pollInterval=task.pollInterval,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Server-specific methods: elicitation and sampling
|
||||
|
||||
def _check_elicitation_capability(self) -> None:
|
||||
"""Check if the client supports elicitation."""
|
||||
if not self._session.check_client_capability(ClientCapabilities(elicitation=ElicitationCapability())):
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_REQUEST,
|
||||
message="Client does not support elicitation capability",
|
||||
)
|
||||
)
|
||||
|
||||
def _check_sampling_capability(self) -> None:
|
||||
"""Check if the client supports sampling."""
|
||||
if not self._session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())):
|
||||
raise McpError(
|
||||
ErrorData(
|
||||
code=INVALID_REQUEST,
|
||||
message="Client does not support sampling capability",
|
||||
)
|
||||
)
|
||||
|
||||
async def elicit(
|
||||
self,
|
||||
message: str,
|
||||
requestedSchema: ElicitRequestedSchema,
|
||||
) -> ElicitResult:
|
||||
"""
|
||||
Send an elicitation request via the task message queue.
|
||||
|
||||
This method:
|
||||
1. Checks client capability
|
||||
2. Updates task status to "input_required"
|
||||
3. Queues the elicitation request
|
||||
4. Waits for the response (delivered via tasks/result round-trip)
|
||||
5. Updates task status back to "working"
|
||||
6. Returns the result
|
||||
|
||||
Args:
|
||||
message: The message to present to the user
|
||||
requestedSchema: Schema defining the expected response structure
|
||||
|
||||
Returns:
|
||||
The client's response
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support elicitation capability
|
||||
"""
|
||||
self._check_elicitation_capability()
|
||||
|
||||
if self._handler is None:
|
||||
raise RuntimeError("handler is required for elicit(). Pass handler= to ServerTaskContext.")
|
||||
|
||||
# Update status to input_required
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
|
||||
|
||||
# Build the request using session's helper
|
||||
request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage]
|
||||
message=message,
|
||||
requestedSchema=requestedSchema,
|
||||
related_task_id=self.task_id,
|
||||
)
|
||||
request_id: RequestId = request.id
|
||||
|
||||
resolver: Resolver[dict[str, Any]] = Resolver()
|
||||
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
queued = QueuedMessage(
|
||||
type="request",
|
||||
message=request,
|
||||
resolver=resolver,
|
||||
original_request_id=request_id,
|
||||
)
|
||||
await self._queue.enqueue(self.task_id, queued)
|
||||
|
||||
try:
|
||||
# Wait for response (routed back via TaskResultHandler)
|
||||
response_data = await resolver.wait()
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
return ElicitResult.model_validate(response_data)
|
||||
except anyio.get_cancelled_exc_class(): # pragma: no cover
|
||||
# Coverage can't track async exception handlers reliably.
|
||||
# This path is tested in test_elicit_restores_status_on_cancellation
|
||||
# which verifies status is restored to "working" after cancellation.
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
raise
|
||||
|
||||
async def elicit_url(
|
||||
self,
|
||||
message: str,
|
||||
url: str,
|
||||
elicitation_id: str,
|
||||
) -> ElicitResult:
|
||||
"""
|
||||
Send a URL mode elicitation request via the task message queue.
|
||||
|
||||
This directs the user to an external URL for out-of-band interactions
|
||||
like OAuth flows, credential collection, or payment processing.
|
||||
|
||||
This method:
|
||||
1. Checks client capability
|
||||
2. Updates task status to "input_required"
|
||||
3. Queues the elicitation request
|
||||
4. Waits for the response (delivered via tasks/result round-trip)
|
||||
5. Updates task status back to "working"
|
||||
6. Returns the result
|
||||
|
||||
Args:
|
||||
message: Human-readable explanation of why the interaction is needed
|
||||
url: The URL the user should navigate to
|
||||
elicitation_id: Unique identifier for tracking this elicitation
|
||||
|
||||
Returns:
|
||||
The client's response indicating acceptance, decline, or cancellation
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support elicitation capability
|
||||
RuntimeError: If handler is not configured
|
||||
"""
|
||||
self._check_elicitation_capability()
|
||||
|
||||
if self._handler is None:
|
||||
raise RuntimeError("handler is required for elicit_url(). Pass handler= to ServerTaskContext.")
|
||||
|
||||
# Update status to input_required
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
|
||||
|
||||
# Build the request using session's helper
|
||||
request = self._session._build_elicit_url_request( # pyright: ignore[reportPrivateUsage]
|
||||
message=message,
|
||||
url=url,
|
||||
elicitation_id=elicitation_id,
|
||||
related_task_id=self.task_id,
|
||||
)
|
||||
request_id: RequestId = request.id
|
||||
|
||||
resolver: Resolver[dict[str, Any]] = Resolver()
|
||||
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
queued = QueuedMessage(
|
||||
type="request",
|
||||
message=request,
|
||||
resolver=resolver,
|
||||
original_request_id=request_id,
|
||||
)
|
||||
await self._queue.enqueue(self.task_id, queued)
|
||||
|
||||
try:
|
||||
# Wait for response (routed back via TaskResultHandler)
|
||||
response_data = await resolver.wait()
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
return ElicitResult.model_validate(response_data)
|
||||
except anyio.get_cancelled_exc_class(): # pragma: no cover
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
raise
|
||||
|
||||
async def create_message(
|
||||
self,
|
||||
messages: list[SamplingMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
system_prompt: str | None = None,
|
||||
include_context: IncludeContext | None = None,
|
||||
temperature: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_preferences: ModelPreferences | None = None,
|
||||
tools: list[Tool] | None = None,
|
||||
tool_choice: ToolChoice | None = None,
|
||||
) -> CreateMessageResult:
|
||||
"""
|
||||
Send a sampling request via the task message queue.
|
||||
|
||||
This method:
|
||||
1. Checks client capability
|
||||
2. Updates task status to "input_required"
|
||||
3. Queues the sampling request
|
||||
4. Waits for the response (delivered via tasks/result round-trip)
|
||||
5. Updates task status back to "working"
|
||||
6. Returns the result
|
||||
|
||||
Args:
|
||||
messages: The conversation messages for sampling
|
||||
max_tokens: Maximum tokens in the response
|
||||
system_prompt: Optional system prompt
|
||||
include_context: Context inclusion strategy
|
||||
temperature: Sampling temperature
|
||||
stop_sequences: Stop sequences
|
||||
metadata: Additional metadata
|
||||
model_preferences: Model selection preferences
|
||||
tools: Optional list of tools the LLM can use during sampling
|
||||
tool_choice: Optional control over tool usage behavior
|
||||
|
||||
Returns:
|
||||
The sampling result from the client
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support sampling capability or tools
|
||||
ValueError: If tool_use or tool_result message structure is invalid
|
||||
"""
|
||||
self._check_sampling_capability()
|
||||
client_caps = self._session.client_params.capabilities if self._session.client_params else None
|
||||
validate_sampling_tools(client_caps, tools, tool_choice)
|
||||
validate_tool_use_result_messages(messages)
|
||||
|
||||
if self._handler is None:
|
||||
raise RuntimeError("handler is required for create_message(). Pass handler= to ServerTaskContext.")
|
||||
|
||||
# Update status to input_required
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
|
||||
|
||||
# Build the request using session's helper
|
||||
request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage]
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
include_context=include_context,
|
||||
temperature=temperature,
|
||||
stop_sequences=stop_sequences,
|
||||
metadata=metadata,
|
||||
model_preferences=model_preferences,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
related_task_id=self.task_id,
|
||||
)
|
||||
request_id: RequestId = request.id
|
||||
|
||||
resolver: Resolver[dict[str, Any]] = Resolver()
|
||||
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
queued = QueuedMessage(
|
||||
type="request",
|
||||
message=request,
|
||||
resolver=resolver,
|
||||
original_request_id=request_id,
|
||||
)
|
||||
await self._queue.enqueue(self.task_id, queued)
|
||||
|
||||
try:
|
||||
# Wait for response (routed back via TaskResultHandler)
|
||||
response_data = await resolver.wait()
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
return CreateMessageResult.model_validate(response_data)
|
||||
except anyio.get_cancelled_exc_class(): # pragma: no cover
|
||||
# Coverage can't track async exception handlers reliably.
|
||||
# This path is tested in test_create_message_restores_status_on_cancellation
|
||||
# which verifies status is restored to "working" after cancellation.
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
raise
|
||||
|
||||
async def elicit_as_task(
|
||||
self,
|
||||
message: str,
|
||||
requestedSchema: ElicitRequestedSchema,
|
||||
*,
|
||||
ttl: int = 60000,
|
||||
) -> ElicitResult:
|
||||
"""
|
||||
Send a task-augmented elicitation via the queue, then poll client.
|
||||
|
||||
This is for use inside a task-augmented tool call when you want the client
|
||||
to handle the elicitation as its own task. The elicitation request is queued
|
||||
and delivered when the client calls tasks/result. After the client responds
|
||||
with CreateTaskResult, we poll the client's task until complete.
|
||||
|
||||
Args:
|
||||
message: The message to present to the user
|
||||
requestedSchema: Schema defining the expected response structure
|
||||
ttl: Task time-to-live in milliseconds for the client's task
|
||||
|
||||
Returns:
|
||||
The client's elicitation response
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support task-augmented elicitation
|
||||
RuntimeError: If handler is not configured
|
||||
"""
|
||||
client_caps = self._session.client_params.capabilities if self._session.client_params else None
|
||||
require_task_augmented_elicitation(client_caps)
|
||||
|
||||
if self._handler is None:
|
||||
raise RuntimeError("handler is required for elicit_as_task()")
|
||||
|
||||
# Update status to input_required
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
|
||||
|
||||
request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage]
|
||||
message=message,
|
||||
requestedSchema=requestedSchema,
|
||||
related_task_id=self.task_id,
|
||||
task=TaskMetadata(ttl=ttl),
|
||||
)
|
||||
request_id: RequestId = request.id
|
||||
|
||||
resolver: Resolver[dict[str, Any]] = Resolver()
|
||||
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
queued = QueuedMessage(
|
||||
type="request",
|
||||
message=request,
|
||||
resolver=resolver,
|
||||
original_request_id=request_id,
|
||||
)
|
||||
await self._queue.enqueue(self.task_id, queued)
|
||||
|
||||
try:
|
||||
# Wait for initial response (CreateTaskResult from client)
|
||||
response_data = await resolver.wait()
|
||||
create_result = CreateTaskResult.model_validate(response_data)
|
||||
client_task_id = create_result.task.taskId
|
||||
|
||||
# Poll the client's task using session.experimental
|
||||
async for _ in self._session.experimental.poll_task(client_task_id):
|
||||
pass
|
||||
|
||||
# Get final result from client
|
||||
result = await self._session.experimental.get_task_result(
|
||||
client_task_id,
|
||||
ElicitResult,
|
||||
)
|
||||
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
return result
|
||||
|
||||
except anyio.get_cancelled_exc_class(): # pragma: no cover
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
raise
|
||||
|
||||
async def create_message_as_task(
|
||||
self,
|
||||
messages: list[SamplingMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
ttl: int = 60000,
|
||||
system_prompt: str | None = None,
|
||||
include_context: IncludeContext | None = None,
|
||||
temperature: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_preferences: ModelPreferences | None = None,
|
||||
tools: list[Tool] | None = None,
|
||||
tool_choice: ToolChoice | None = None,
|
||||
) -> CreateMessageResult:
|
||||
"""
|
||||
Send a task-augmented sampling request via the queue, then poll client.
|
||||
|
||||
This is for use inside a task-augmented tool call when you want the client
|
||||
to handle the sampling as its own task. The request is queued and delivered
|
||||
when the client calls tasks/result. After the client responds with
|
||||
CreateTaskResult, we poll the client's task until complete.
|
||||
|
||||
Args:
|
||||
messages: The conversation messages for sampling
|
||||
max_tokens: Maximum tokens in the response
|
||||
ttl: Task time-to-live in milliseconds for the client's task
|
||||
system_prompt: Optional system prompt
|
||||
include_context: Context inclusion strategy
|
||||
temperature: Sampling temperature
|
||||
stop_sequences: Stop sequences
|
||||
metadata: Additional metadata
|
||||
model_preferences: Model selection preferences
|
||||
tools: Optional list of tools the LLM can use during sampling
|
||||
tool_choice: Optional control over tool usage behavior
|
||||
|
||||
Returns:
|
||||
The sampling result from the client
|
||||
|
||||
Raises:
|
||||
McpError: If client doesn't support task-augmented sampling or tools
|
||||
ValueError: If tool_use or tool_result message structure is invalid
|
||||
RuntimeError: If handler is not configured
|
||||
"""
|
||||
client_caps = self._session.client_params.capabilities if self._session.client_params else None
|
||||
require_task_augmented_sampling(client_caps)
|
||||
validate_sampling_tools(client_caps, tools, tool_choice)
|
||||
validate_tool_use_result_messages(messages)
|
||||
|
||||
if self._handler is None:
|
||||
raise RuntimeError("handler is required for create_message_as_task()")
|
||||
|
||||
# Update status to input_required
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
|
||||
|
||||
# Build request WITH task field for task-augmented sampling
|
||||
request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage]
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
include_context=include_context,
|
||||
temperature=temperature,
|
||||
stop_sequences=stop_sequences,
|
||||
metadata=metadata,
|
||||
model_preferences=model_preferences,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
related_task_id=self.task_id,
|
||||
task=TaskMetadata(ttl=ttl),
|
||||
)
|
||||
request_id: RequestId = request.id
|
||||
|
||||
resolver: Resolver[dict[str, Any]] = Resolver()
|
||||
self._handler._pending_requests[request_id] = resolver # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
queued = QueuedMessage(
|
||||
type="request",
|
||||
message=request,
|
||||
resolver=resolver,
|
||||
original_request_id=request_id,
|
||||
)
|
||||
await self._queue.enqueue(self.task_id, queued)
|
||||
|
||||
try:
|
||||
# Wait for initial response (CreateTaskResult from client)
|
||||
response_data = await resolver.wait()
|
||||
create_result = CreateTaskResult.model_validate(response_data)
|
||||
client_task_id = create_result.task.taskId
|
||||
|
||||
# Poll the client's task using session.experimental
|
||||
async for _ in self._session.experimental.poll_task(client_task_id):
|
||||
pass
|
||||
|
||||
# Get final result from client
|
||||
result = await self._session.experimental.get_task_result(
|
||||
client_task_id,
|
||||
CreateMessageResult,
|
||||
)
|
||||
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
return result
|
||||
|
||||
except anyio.get_cancelled_exc_class(): # pragma: no cover
|
||||
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
|
||||
raise
|
||||
Reference in New Issue
Block a user