Fix project isolation: Make loadChatHistory respect active project sessions

- Modified loadChatHistory() to check for active project before fetching all sessions
- When active project exists, use project.sessions instead of fetching from API
- Added detailed console logging to debug session filtering
- This prevents ALL sessions from appearing in every project's sidebar

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
uroma
2026-01-22 14:43:05 +00:00
Unverified
parent b82837aa5f
commit 55aafbae9a
6463 changed files with 1115462 additions and 4486 deletions

View File

@@ -0,0 +1,5 @@
# Re-export the core server entry points at the package level so callers can
# write e.g. `from mcp.server import FastMCP`.
from .fastmcp import FastMCP
from .lowlevel import NotificationOptions, Server
from .models import InitializationOptions

# Explicit public API of this package.
__all__ = ["Server", "FastMCP", "NotificationOptions", "InitializationOptions"]

View File

@@ -0,0 +1,50 @@
import importlib.metadata
import logging
import sys

import anyio

from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.server.stdio import stdio_server
from mcp.types import ServerCapabilities

# Silence warnings unless the user explicitly enabled them (-W / PYTHONWARNINGS).
if not sys.warnoptions:
    import warnings

    warnings.simplefilter("ignore")

# Basic stderr logging for this example server.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("server")
async def receive_loop(session: ServerSession):
    """Consume messages from the client until the session's stream closes.

    Exceptions surfaced on the incoming stream are logged and skipped; every
    other message is logged at INFO level.
    """
    logger.info("Starting receive loop")
    async for incoming in session.incoming_messages:
        if isinstance(incoming, Exception):
            logger.error("Error: %s", incoming)
        else:
            logger.info("Received message from client: %s", incoming)
async def main():
    """Run a minimal MCP server over stdio, logging every incoming message."""
    # Advertise the installed SDK version during initialization.
    version = importlib.metadata.version("mcp")
    async with stdio_server() as (read_stream, write_stream):
        async with (
            ServerSession(
                read_stream,
                write_stream,
                InitializationOptions(
                    server_name="mcp",
                    server_version=version,
                    capabilities=ServerCapabilities(),
                ),
            ) as session,
            write_stream,  # NOTE(review): re-enters write_stream so it is closed on exit — confirm intended
        ):
            await receive_loop(session)
# Script entry point: run the async main under anyio's trio backend.
if __name__ == "__main__":
    anyio.run(main, backend="trio")

View File

@@ -0,0 +1,3 @@
"""
MCP OAuth server authorization components.
"""

View File

@@ -0,0 +1,5 @@
from pydantic import ValidationError
def stringify_pydantic_error(validation_error: ValidationError) -> str:
    """Flatten a pydantic ValidationError into one human-readable line per error.

    Each line has the form ``loc.path: message``, with the location path joined
    by dots.
    """
    lines: list[str] = []
    for err in validation_error.errors():
        location = ".".join(str(part) for part in err["loc"])
        lines.append(f"{location}: {err['msg']}")
    return "\n".join(lines)

View File

@@ -0,0 +1,3 @@
"""
Request handlers for MCP authorization endpoints.
"""

View File

@@ -0,0 +1,224 @@
import logging
from dataclasses import dataclass
from typing import Any, Literal
from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError
from starlette.datastructures import FormData, QueryParams
from starlette.requests import Request
from starlette.responses import RedirectResponse, Response
from mcp.server.auth.errors import stringify_pydantic_error
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.provider import (
AuthorizationErrorCode,
AuthorizationParams,
AuthorizeError,
OAuthAuthorizationServerProvider,
construct_redirect_uri,
)
from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError
logger = logging.getLogger(__name__)
class AuthorizationRequest(BaseModel):
    """Query/form parameters accepted by the authorization endpoint.

    See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
    """

    client_id: str = Field(..., description="The client ID")
    redirect_uri: AnyUrl | None = Field(None, description="URL to redirect to after authorization")
    # see OAuthClientMetadata; we only support `code`
    response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow")
    code_challenge: str = Field(..., description="PKCE code challenge")
    code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256")
    state: str | None = Field(None, description="Optional state parameter")
    scope: str | None = Field(
        None,
        description="Optional scope; if specified, should be a space-separated list of scope strings",
    )
    resource: str | None = Field(
        None,
        description="RFC 8707 resource indicator - the MCP server this token will be used with",
    )
class AuthorizationErrorResponse(BaseModel):
    """Error payload for failed authorization requests.

    See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1
    """

    error: AuthorizationErrorCode
    error_description: str | None
    error_uri: AnyUrl | None = None
    # must be set if provided in the request
    state: str | None = None
def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None:
    """Return ``params[key]`` when present and a plain string, else None.

    Used during error handling, where the request may not have been parsed or
    validated yet, so every lookup is best-effort.
    """
    if params is None:  # pragma: no cover
        return None
    candidate = params.get(key)
    return candidate if isinstance(candidate, str) else None
class AnyUrlModel(RootModel[AnyUrl]):
    """Root model used to validate a lone string as an AnyUrl."""

    root: AnyUrl
@dataclass
class AuthorizationHandler:
    """Starlette endpoint implementing the OAuth 2.0 authorization endpoint.

    Only the authorization-code flow with PKCE (S256) is supported; see
    https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
    """

    # Pluggable provider supplying client lookup and the authorization decision.
    provider: OAuthAuthorizationServerProvider[Any, Any, Any]

    async def handle(self, request: Request) -> Response:
        # implements authorization requests for grant_type=code;
        # see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
        state = None
        redirect_uri = None
        client = None
        params = None

        async def error_response(
            error: AuthorizationErrorCode,
            error_description: str | None,
            attempt_load_client: bool = True,
        ):
            # Error responses take two different formats:
            # 1. The request has a valid client ID & redirect_uri: we issue a redirect
            #    back to the redirect_uri with the error response fields as query
            #    parameters. This allows the client to be notified of the error.
            # 2. Otherwise, we return an error response directly to the end user;
            #    we choose to do so in JSON, but this is left undefined in the
            #    specification.
            # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1
            #
            # This logic is a bit awkward to handle, because the error might be thrown
            # very early in request validation, before we've done the usual Pydantic
            # validation, loaded the client, etc. To handle this, error_response()
            # contains fallback logic which attempts to load the parameters directly
            # from the request.

            nonlocal client, redirect_uri, state
            if client is None and attempt_load_client:
                # make last-ditch attempt to load the client
                client_id = best_effort_extract_string("client_id", params)
                client = await self.provider.get_client(client_id) if client_id else None
            if redirect_uri is None and client:
                # make last-ditch effort to load the redirect uri
                try:
                    if params is not None and "redirect_uri" not in params:
                        raw_redirect_uri = None
                    else:
                        raw_redirect_uri = AnyUrlModel.model_validate(
                            best_effort_extract_string("redirect_uri", params)
                        ).root
                    redirect_uri = client.validate_redirect_uri(raw_redirect_uri)
                except (ValidationError, InvalidRedirectUriError):
                    # if the redirect URI is invalid, ignore it & just return the
                    # initial error
                    pass

            # the error response MUST contain the state specified by the client, if any
            if state is None:  # pragma: no cover
                # make last-ditch effort to load state
                state = best_effort_extract_string("state", params)

            error_resp = AuthorizationErrorResponse(
                error=error,
                error_description=error_description,
                state=state,
            )

            if redirect_uri and client:
                # Case 1: notify the client via redirect.
                return RedirectResponse(
                    url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)),
                    status_code=302,
                    headers={"Cache-Control": "no-store"},
                )
            else:
                # Case 2: direct JSON error to the end user.
                return PydanticJSONResponse(
                    status_code=400,
                    content=error_resp,
                    headers={"Cache-Control": "no-store"},
                )

        try:
            # Parse request parameters
            if request.method == "GET":
                # Convert query_params to dict for pydantic validation
                params = request.query_params
            else:
                # Parse form data for POST requests
                params = await request.form()

            # Save state if it exists, even before validation
            state = best_effort_extract_string("state", params)

            try:
                auth_request = AuthorizationRequest.model_validate(params)
                state = auth_request.state  # Update with validated state
            except ValidationError as validation_error:
                error: AuthorizationErrorCode = "invalid_request"
                # A wrong response_type maps to a more specific OAuth error code.
                for e in validation_error.errors():
                    if e["loc"] == ("response_type",) and e["type"] == "literal_error":
                        error = "unsupported_response_type"
                        break
                return await error_response(error, stringify_pydantic_error(validation_error))

            # Get client information
            client = await self.provider.get_client(
                auth_request.client_id,
            )
            if not client:
                # For client_id validation errors, return direct error (no redirect)
                return await error_response(
                    error="invalid_request",
                    error_description=f"Client ID '{auth_request.client_id}' not found",
                    attempt_load_client=False,
                )

            # Validate redirect_uri against client's registered URIs
            try:
                redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri)
            except InvalidRedirectUriError as validation_error:
                # For redirect_uri validation errors, return direct error (no redirect)
                return await error_response(
                    error="invalid_request",
                    error_description=validation_error.message,
                )

            # Validate scope - for scope errors, we can redirect
            try:
                scopes = client.validate_scope(auth_request.scope)
            except InvalidScopeError as validation_error:
                # For scope errors, redirect with error parameters
                return await error_response(
                    error="invalid_scope",
                    error_description=validation_error.message,
                )

            # Setup authorization parameters
            auth_params = AuthorizationParams(
                state=state,
                scopes=scopes,
                code_challenge=auth_request.code_challenge,
                redirect_uri=redirect_uri,
                redirect_uri_provided_explicitly=auth_request.redirect_uri is not None,
                resource=auth_request.resource,  # RFC 8707
            )

            try:
                # Let the provider pick the next URI to redirect to
                return RedirectResponse(
                    url=await self.provider.authorize(
                        client,
                        auth_params,
                    ),
                    status_code=302,
                    headers={"Cache-Control": "no-store"},
                )
            except AuthorizeError as e:
                # Handle authorization errors as defined in RFC 6749 Section 4.1.2.1
                return await error_response(error=e.error, error_description=e.error_description)

        except Exception as validation_error:  # pragma: no cover
            # Catch-all for unexpected errors
            logger.exception("Unexpected error in authorization_handler", exc_info=validation_error)
            return await error_response(error="server_error", error_description="An unexpected error occurred")

View File

@@ -0,0 +1,29 @@
from dataclasses import dataclass
from starlette.requests import Request
from starlette.responses import Response
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata
@dataclass
class MetadataHandler:
    """Serves OAuth authorization-server metadata as cacheable JSON."""

    metadata: OAuthMetadata

    async def handle(self, request: Request) -> Response:
        # The metadata document is static, so let clients cache it for an hour.
        cache_headers = {"Cache-Control": "public, max-age=3600"}
        return PydanticJSONResponse(content=self.metadata, headers=cache_headers)
@dataclass
class ProtectedResourceMetadataHandler:
    """Serves protected-resource metadata as cacheable JSON."""

    metadata: ProtectedResourceMetadata

    async def handle(self, request: Request) -> Response:
        # The metadata document is static, so let clients cache it for an hour.
        cache_headers = {"Cache-Control": "public, max-age=3600"}
        return PydanticJSONResponse(content=self.metadata, headers=cache_headers)

View File

@@ -0,0 +1,136 @@
import secrets
import time
from dataclasses import dataclass
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, RootModel, ValidationError
from starlette.requests import Request
from starlette.responses import Response
from mcp.server.auth.errors import stringify_pydantic_error
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, RegistrationError, RegistrationErrorCode
from mcp.server.auth.settings import ClientRegistrationOptions
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
class RegistrationRequest(RootModel[OAuthClientMetadata]):
    """HTTP-layer wrapper around the client-supplied registration metadata."""

    # this wrapper is a no-op; it's just to separate out the types exposed to the
    # provider from what we use in the HTTP handler
    root: OAuthClientMetadata
class RegistrationErrorResponse(BaseModel):
    """Error payload for failed registrations (RFC 7591 Section 3.2.2)."""

    error: RegistrationErrorCode
    error_description: str | None
@dataclass
class RegistrationHandler:
    """Starlette endpoint implementing OAuth 2.0 dynamic client registration.

    See https://datatracker.ietf.org/doc/html/rfc7591#section-3.1
    """

    # Provider that persists newly registered clients.
    provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    # Server-side registration policy (default/valid scopes, secret expiry).
    options: ClientRegistrationOptions

    async def handle(self, request: Request) -> Response:
        # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1
        try:
            # Parse request body as JSON
            body = await request.json()
            client_metadata = OAuthClientMetadata.model_validate(body)

            # Scope validation is handled below
        except ValidationError as validation_error:
            return PydanticJSONResponse(
                content=RegistrationErrorResponse(
                    error="invalid_client_metadata",
                    error_description=stringify_pydantic_error(validation_error),
                ),
                status_code=400,
            )

        # Server-assigned client identifier.
        client_id = str(uuid4())

        # If auth method is None, default to client_secret_post
        if client_metadata.token_endpoint_auth_method is None:
            client_metadata.token_endpoint_auth_method = "client_secret_post"

        client_secret = None
        if client_metadata.token_endpoint_auth_method != "none":  # pragma: no branch
            # cryptographically secure random 32-byte hex string
            client_secret = secrets.token_hex(32)

        # Apply server scope policy: fall back to defaults, or check requested
        # scopes against the allowed set.
        if client_metadata.scope is None and self.options.default_scopes is not None:
            client_metadata.scope = " ".join(self.options.default_scopes)
        elif client_metadata.scope is not None and self.options.valid_scopes is not None:
            requested_scopes = set(client_metadata.scope.split())
            valid_scopes = set(self.options.valid_scopes)
            if not requested_scopes.issubset(valid_scopes):  # pragma: no branch
                return PydanticJSONResponse(
                    content=RegistrationErrorResponse(
                        error="invalid_client_metadata",
                        error_description="Requested scopes are not valid: "
                        f"{', '.join(requested_scopes - valid_scopes)}",
                    ),
                    status_code=400,
                )

        # Both grant types must be registered together.
        if not {"authorization_code", "refresh_token"}.issubset(set(client_metadata.grant_types)):
            return PydanticJSONResponse(
                content=RegistrationErrorResponse(
                    error="invalid_client_metadata",
                    error_description="grant_types must be authorization_code and refresh_token",
                ),
                status_code=400,
            )

        # The MCP spec requires servers to use the authorization `code` flow
        # with PKCE
        if "code" not in client_metadata.response_types:
            return PydanticJSONResponse(
                content=RegistrationErrorResponse(
                    error="invalid_client_metadata",
                    error_description="response_types must include 'code' for authorization_code grant",
                ),
                status_code=400,
            )

        client_id_issued_at = int(time.time())
        # Secret expiry is relative to issuance; None means the secret never expires.
        client_secret_expires_at = (
            client_id_issued_at + self.options.client_secret_expiry_seconds
            if self.options.client_secret_expiry_seconds is not None
            else None
        )
        client_info = OAuthClientInformationFull(
            client_id=client_id,
            client_id_issued_at=client_id_issued_at,
            client_secret=client_secret,
            client_secret_expires_at=client_secret_expires_at,
            # passthrough information from the client request
            redirect_uris=client_metadata.redirect_uris,
            token_endpoint_auth_method=client_metadata.token_endpoint_auth_method,
            grant_types=client_metadata.grant_types,
            response_types=client_metadata.response_types,
            client_name=client_metadata.client_name,
            client_uri=client_metadata.client_uri,
            logo_uri=client_metadata.logo_uri,
            scope=client_metadata.scope,
            contacts=client_metadata.contacts,
            tos_uri=client_metadata.tos_uri,
            policy_uri=client_metadata.policy_uri,
            jwks_uri=client_metadata.jwks_uri,
            jwks=client_metadata.jwks,
            software_id=client_metadata.software_id,
            software_version=client_metadata.software_version,
        )
        try:
            # Register client
            await self.provider.register_client(client_info)

            # Return client information
            return PydanticJSONResponse(content=client_info, status_code=201)
        except RegistrationError as e:
            # Handle registration errors as defined in RFC 7591 Section 3.2.2
            return PydanticJSONResponse(
                content=RegistrationErrorResponse(error=e.error, error_description=e.error_description),
                status_code=400,
            )

View File

@@ -0,0 +1,91 @@
from dataclasses import dataclass
from functools import partial
from typing import Any, Literal
from pydantic import BaseModel, ValidationError
from starlette.requests import Request
from starlette.responses import Response
from mcp.server.auth.errors import (
stringify_pydantic_error,
)
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, RefreshToken
class RevocationRequest(BaseModel):
    """Form parameters for the OAuth 2.0 token revocation endpoint.

    See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
    """

    token: str
    token_type_hint: Literal["access_token", "refresh_token"] | None = None
    client_id: str
    # Default to None so the field is genuinely optional: in pydantic v2 a
    # bare `str | None` annotation is still a *required* field, which rejected
    # requests from clients that authenticate via the Authorization header
    # (client_secret_basic) or were registered with auth method "none" and so
    # send no client_secret form field. Matches the client_secret fields of
    # the token-endpoint request models.
    client_secret: str | None = None
class RevocationErrorResponse(BaseModel):
    """Error payload returned by the revocation endpoint."""

    error: Literal["invalid_request", "unauthorized_client"]
    error_description: str | None = None
@dataclass
class RevocationHandler:
    """Starlette endpoint implementing OAuth 2.0 token revocation (RFC 7009)."""

    provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    # Validates the client credentials supplied with the revocation request.
    client_authenticator: ClientAuthenticator

    async def handle(self, request: Request) -> Response:
        """
        Handler for the OAuth 2.0 Token Revocation endpoint.
        """
        try:
            client = await self.client_authenticator.authenticate_request(request)
        except AuthenticationError as e:  # pragma: no cover
            return PydanticJSONResponse(
                status_code=401,
                content=RevocationErrorResponse(
                    error="unauthorized_client",
                    error_description=e.message,
                ),
            )

        try:
            form_data = await request.form()
            revocation_request = RevocationRequest.model_validate(dict(form_data))
        except ValidationError as e:
            return PydanticJSONResponse(
                status_code=400,
                content=RevocationErrorResponse(
                    error="invalid_request",
                    error_description=stringify_pydantic_error(e),
                ),
            )

        # Try the access-token loader first unless the client hinted that this
        # is a refresh token, in which case try the refresh-token loader first.
        loaders = [
            self.provider.load_access_token,
            partial(self.provider.load_refresh_token, client),
        ]
        if revocation_request.token_type_hint == "refresh_token":  # pragma: no cover
            loaders = reversed(loaders)

        token: None | AccessToken | RefreshToken = None
        for loader in loaders:
            token = await loader(revocation_request.token)
            if token is not None:
                break

        # if token is not found, just return HTTP 200 per the RFC
        if token and token.client_id == client.client_id:
            # Revoke token; provider is not meant to be able to do validation
            # at this point that would result in an error
            await self.provider.revoke_token(token)

        # Return successful empty response
        return Response(
            status_code=200,
            headers={
                "Cache-Control": "no-store",
                "Pragma": "no-cache",
            },
        )

View File

@@ -0,0 +1,241 @@
import base64
import hashlib
import time
from dataclasses import dataclass
from typing import Annotated, Any, Literal
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
from starlette.requests import Request
from mcp.server.auth.errors import stringify_pydantic_error
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode
from mcp.shared.auth import OAuthToken
class AuthorizationCodeRequest(BaseModel):
    """Token request for the authorization_code grant.

    See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
    """

    grant_type: Literal["authorization_code"]
    code: str = Field(..., description="The authorization code")
    redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize")
    client_id: str
    # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
    client_secret: str | None = None
    # See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
    code_verifier: str = Field(..., description="PKCE code verifier")
    # RFC 8707 resource indicator
    resource: str | None = Field(None, description="Resource indicator for the token")
class RefreshTokenRequest(BaseModel):
    """Token request for the refresh_token grant.

    See https://datatracker.ietf.org/doc/html/rfc6749#section-6
    """

    grant_type: Literal["refresh_token"]
    refresh_token: str = Field(..., description="The refresh token")
    scope: str | None = Field(None, description="Optional scope parameter")
    client_id: str
    # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
    client_secret: str | None = None
    # RFC 8707 resource indicator
    resource: str | None = Field(None, description="Resource indicator for the token")
class TokenRequest(
    RootModel[
        Annotated[
            AuthorizationCodeRequest | RefreshTokenRequest,
            Field(discriminator="grant_type"),
        ]
    ]
):
    """Discriminated union of the supported token requests, keyed on grant_type."""

    root: Annotated[
        AuthorizationCodeRequest | RefreshTokenRequest,
        Field(discriminator="grant_type"),
    ]
class TokenErrorResponse(BaseModel):
    """
    Error payload for failed token requests.

    See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
    """

    error: TokenErrorCode
    error_description: str | None = None
    error_uri: AnyHttpUrl | None = None
class TokenSuccessResponse(RootModel[OAuthToken]):
    """HTTP-layer wrapper around a successfully issued token set."""

    # this is just a wrapper over OAuthToken; the only reason we do this
    # is to have some separation between the HTTP response type, and the
    # type returned by the provider
    root: OAuthToken
@dataclass
class TokenHandler:
    """Starlette endpoint implementing the OAuth 2.0 token endpoint.

    Supports the authorization_code grant (with mandatory PKCE verification)
    and the refresh_token grant.
    """

    provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    # Validates the client credentials attached to the token request.
    client_authenticator: ClientAuthenticator

    def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
        """Wrap a success/error payload in a non-cacheable JSON response."""
        status_code = 200
        if isinstance(obj, TokenErrorResponse):
            status_code = 400

        return PydanticJSONResponse(
            content=obj,
            status_code=status_code,
            headers={
                "Cache-Control": "no-store",
                "Pragma": "no-cache",
            },
        )

    async def handle(self, request: Request):
        """Authenticate the client, validate the grant, and issue tokens."""
        try:
            client_info = await self.client_authenticator.authenticate_request(request)
        except AuthenticationError as e:
            # Authentication failures should return 401
            return PydanticJSONResponse(
                content=TokenErrorResponse(
                    error="unauthorized_client",
                    error_description=e.message,
                ),
                status_code=401,
                headers={
                    "Cache-Control": "no-store",
                    "Pragma": "no-cache",
                },
            )

        try:
            form_data = await request.form()
            token_request = TokenRequest.model_validate(dict(form_data)).root
        except ValidationError as validation_error:  # pragma: no cover
            return self.response(
                TokenErrorResponse(
                    error="invalid_request",
                    error_description=stringify_pydantic_error(validation_error),
                )
            )

        # A client may only use grant types it registered with.
        if token_request.grant_type not in client_info.grant_types:  # pragma: no cover
            return self.response(
                TokenErrorResponse(
                    error="unsupported_grant_type",
                    error_description=(f"Unsupported grant type (supported grant types are {client_info.grant_types})"),
                )
            )

        tokens: OAuthToken

        match token_request:
            case AuthorizationCodeRequest():
                auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
                if auth_code is None or auth_code.client_id != token_request.client_id:
                    # if code belongs to different client, pretend it doesn't exist
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="authorization code does not exist",
                        )
                    )

                # make auth codes expire after a deadline
                # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
                if auth_code.expires_at < time.time():
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="authorization code has expired",
                        )
                    )

                # verify redirect_uri doesn't change between /authorize and /tokens
                # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
                if auth_code.redirect_uri_provided_explicitly:
                    authorize_request_redirect_uri = auth_code.redirect_uri
                else:  # pragma: no cover
                    authorize_request_redirect_uri = None

                # Convert both sides to strings for comparison to handle AnyUrl vs string issues
                token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None
                auth_redirect_str = (
                    str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None
                )

                if token_redirect_str != auth_redirect_str:
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_request",
                            error_description=("redirect_uri did not match the one used when creating auth code"),
                        )
                    )

                # Verify PKCE code verifier: S256(code_verifier) must equal the
                # code_challenge captured at /authorize.
                sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
                hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")

                if hashed_code_verifier != auth_code.code_challenge:
                    # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="incorrect code_verifier",
                        )
                    )

                try:
                    # Exchange authorization code for tokens
                    tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
                except TokenError as e:
                    return self.response(
                        TokenErrorResponse(
                            error=e.error,
                            error_description=e.error_description,
                        )
                    )

            case RefreshTokenRequest():  # pragma: no cover
                refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
                if refresh_token is None or refresh_token.client_id != token_request.client_id:
                    # if token belongs to different client, pretend it doesn't exist
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="refresh token does not exist",
                        )
                    )

                if refresh_token.expires_at and refresh_token.expires_at < time.time():
                    # if the refresh token has expired, pretend it doesn't exist
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="refresh token has expired",
                        )
                    )

                # Parse scopes if provided; the refreshed grant may not widen
                # the scope set of the original refresh token.
                scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes

                for scope in scopes:
                    if scope not in refresh_token.scopes:
                        return self.response(
                            TokenErrorResponse(
                                error="invalid_scope",
                                error_description=(f"cannot request scope `{scope}` not provided by refresh token"),
                            )
                        )

                try:
                    # Exchange refresh token for new tokens
                    tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
                except TokenError as e:
                    return self.response(
                        TokenErrorResponse(
                            error=e.error,
                            error_description=e.error_description,
                        )
                    )

        return self.response(TokenSuccessResponse(root=tokens))

View File

@@ -0,0 +1,10 @@
from typing import Any
from starlette.responses import JSONResponse
class PydanticJSONResponse(JSONResponse):
    """JSON response that serializes its content via pydantic.

    Starlette's stock JSONResponse uses ``json.dumps``, which cannot handle
    pydantic models or field types like AnyHttpUrl; delegating to the model's
    own ``model_dump_json`` keeps serialization consistent with the schema.
    """

    def render(self, content: Any) -> bytes:
        serialized = content.model_dump_json(exclude_none=True)
        return serialized.encode("utf-8")

View File

@@ -0,0 +1,3 @@
"""
Middleware for MCP authorization.
"""

View File

@@ -0,0 +1,48 @@
import contextvars
from starlette.types import ASGIApp, Receive, Scope, Send
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from mcp.server.auth.provider import AccessToken
# Context-local storage for the currently authenticated user.
# The default is None, indicating no authenticated user is present.
auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None)
def get_access_token() -> AccessToken | None:
    """
    Get the access token from the current context.

    Returns:
        The access token if an authenticated user is available, None otherwise.
    """
    current_user = auth_context_var.get()
    if current_user is None:
        return None
    return current_user.access_token
class AuthContextMiddleware:
    """
    ASGI middleware that publishes the authenticated user via a contextvar.

    Place it after the AuthenticationMiddleware in the middleware stack so
    that ``scope["user"]`` has already been populated; downstream code can
    then read the user anywhere in the request's call stack.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        user = scope.get("user")
        if not isinstance(user, AuthenticatedUser):
            # No authenticated user: run the app without touching the context.
            await self.app(scope, receive, send)
            return

        reset_token = auth_context_var.set(user)
        try:
            await self.app(scope, receive, send)
        finally:
            # Always restore the previous value, even if the app raised.
            auth_context_var.reset(reset_token)

View File

@@ -0,0 +1,128 @@
import json
import time
from typing import Any
from pydantic import AnyHttpUrl
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
from starlette.requests import HTTPConnection
from starlette.types import Receive, Scope, Send
from mcp.server.auth.provider import AccessToken, TokenVerifier
class AuthenticatedUser(SimpleUser):
    """Starlette user carrying the validated access token.

    The username is the token's client_id.
    """

    def __init__(self, auth_info: AccessToken):
        super().__init__(auth_info.client_id)
        # Keep the full token and its scopes for downstream authorization checks.
        self.access_token = auth_info
        self.scopes = auth_info.scopes
class BearerAuthBackend(AuthenticationBackend):
    """
    Authentication backend that validates Bearer tokens using a TokenVerifier.
    """

    def __init__(self, token_verifier: TokenVerifier):
        self.token_verifier = token_verifier

    async def authenticate(self, conn: HTTPConnection):
        # Locate the Authorization header case-insensitively.
        auth_header = None
        for header_name in conn.headers:
            if header_name.lower() == "authorization":
                auth_header = conn.headers.get(header_name)
                break
        if not auth_header or not auth_header.lower().startswith("bearer "):
            return None

        token = auth_header[7:]  # strip the "Bearer " prefix

        # Delegate validation to the configured verifier.
        auth_info = await self.token_verifier.verify_token(token)
        if not auth_info:
            return None
        # Reject tokens that have already expired.
        if auth_info.expires_at and auth_info.expires_at < int(time.time()):
            return None

        return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info)
class RequireAuthMiddleware:
    """
    Middleware that requires a valid Bearer token in the Authorization header.

    This will validate the token with the auth provider and store the resulting
    auth info in the request state.
    """

    def __init__(
        self,
        app: Any,
        required_scopes: list[str],
        resource_metadata_url: AnyHttpUrl | None = None,
    ):
        """
        Initialize the middleware.

        Args:
            app: ASGI application
            required_scopes: List of scopes that the token must have
            resource_metadata_url: Optional protected resource metadata URL for WWW-Authenticate header
        """
        self.app = app
        self.required_scopes = required_scopes
        self.resource_metadata_url = resource_metadata_url

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # 401 when no upstream authentication backend attached a user to the scope.
        auth_user = scope.get("user")
        if not isinstance(auth_user, AuthenticatedUser):
            await self._send_auth_error(
                send, status_code=401, error="invalid_token", description="Authentication required"
            )
            return

        # 403 when the token is missing any of the required scopes.
        auth_credentials = scope.get("auth")
        for required_scope in self.required_scopes:
            # auth_credentials should always be provided; this is just paranoia
            if auth_credentials is None or required_scope not in auth_credentials.scopes:
                await self._send_auth_error(
                    send, status_code=403, error="insufficient_scope", description=f"Required scope: {required_scope}"
                )
                return

        await self.app(scope, receive, send)

    async def _send_auth_error(self, send: Send, status_code: int, error: str, description: str) -> None:
        """Send an authentication error response with WWW-Authenticate header."""
        # Build WWW-Authenticate header value
        www_auth_parts = [f'error="{error}"', f'error_description="{description}"']
        if self.resource_metadata_url:  # pragma: no cover
            www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"')
        www_authenticate = f"Bearer {', '.join(www_auth_parts)}"

        # Send response
        body = {"error": error, "error_description": description}
        body_bytes = json.dumps(body).encode()

        # Emit the raw ASGI response: start message with headers, then the body.
        await send(
            {
                "type": "http.response.start",
                "status": status_code,
                "headers": [
                    (b"content-type", b"application/json"),
                    (b"content-length", str(len(body_bytes)).encode()),
                    (b"www-authenticate", www_authenticate.encode()),
                ],
            }
        )
        await send(
            {
                "type": "http.response.body",
                "body": body_bytes,
            }
        )

View File

@@ -0,0 +1,115 @@
import base64
import binascii
import hmac
import time
from typing import Any
from urllib.parse import unquote
from starlette.requests import Request
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
from mcp.shared.auth import OAuthClientInformationFull
class AuthenticationError(Exception):
    """Raised when client authentication for a token/revocation request fails.

    Attributes:
        message: Human-readable description of the failure.
    """

    def __init__(self, message: str):
        # Pass the message to Exception so str(exc) and exc.args carry it;
        # the original only set .message, leaving str(exc) empty in logs.
        super().__init__(message)
        self.message = message
class ClientAuthenticator:
    """
    ClientAuthenticator is a callable which validates requests from a client
    application, used to verify /token calls.

    If, during registration, the client requested to be issued a secret, the
    authenticator asserts that /token calls must be authenticated with
    that same token.

    NOTE: clients can opt for no authentication during registration, in which case this
    logic is skipped.
    """

    def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
        """
        Initialize the dependency.

        Args:
            provider: Provider to look up client information
        """
        self.provider = provider

    async def authenticate_request(self, request: Request) -> OAuthClientInformationFull:
        """
        Authenticate a client from an HTTP request.

        Extracts client credentials from the appropriate location based on the
        client's registered authentication method and validates them.

        Args:
            request: The HTTP request containing client credentials

        Returns:
            The authenticated client information

        Raises:
            AuthenticationError: If authentication fails
        """
        # client_id is always carried in the form body regardless of auth method.
        form_data = await request.form()
        client_id = form_data.get("client_id")
        if not client_id:
            raise AuthenticationError("Missing client_id")
        client = await self.provider.get_client(str(client_id))
        if not client:
            raise AuthenticationError("Invalid client_id")  # pragma: no cover
        request_client_secret: str | None = None
        auth_header = request.headers.get("Authorization", "")
        if client.token_endpoint_auth_method == "client_secret_basic":
            # RFC 6749 §2.3.1: "Basic base64(urlencode(client_id):urlencode(secret))".
            if not auth_header.startswith("Basic "):
                raise AuthenticationError("Missing or invalid Basic authentication in Authorization header")
            try:
                encoded_credentials = auth_header[6:]  # Remove "Basic " prefix
                decoded = base64.b64decode(encoded_credentials).decode("utf-8")
                if ":" not in decoded:
                    raise ValueError("Invalid Basic auth format")
                basic_client_id, request_client_secret = decoded.split(":", 1)
                # URL-decode both parts per RFC 6749 Section 2.3.1
                basic_client_id = unquote(basic_client_id)
                request_client_secret = unquote(request_client_secret)
                if basic_client_id != client_id:
                    # AuthenticationError is not in the except tuple below, so this
                    # specific message propagates instead of being rewrapped.
                    raise AuthenticationError("Client ID mismatch in Basic auth")
            except (ValueError, UnicodeDecodeError, binascii.Error):
                # Covers malformed base64, bad UTF-8, and the missing-colon case above.
                raise AuthenticationError("Invalid Basic authentication header")
        elif client.token_endpoint_auth_method == "client_secret_post":
            raw_form_data = form_data.get("client_secret")
            # form_data.get() can return a UploadFile or None, so we need to check if it's a string
            if isinstance(raw_form_data, str):
                request_client_secret = str(raw_form_data)
        elif client.token_endpoint_auth_method == "none":
            # Public client: no secret expected; validation below is skipped
            # unless the stored client unexpectedly has one.
            request_client_secret = None
        else:
            raise AuthenticationError(  # pragma: no cover
                f"Unsupported auth method: {client.token_endpoint_auth_method}"
            )
        # If client from the store expects a secret, validate that the request provides
        # that secret
        if client.client_secret:  # pragma: no branch
            if not request_client_secret:
                raise AuthenticationError("Client secret is required")  # pragma: no cover
            # hmac.compare_digest requires that both arguments are either bytes or a `str` containing
            # only ASCII characters. Since we do not control `request_client_secret`, we encode both
            # arguments to bytes. Constant-time compare avoids timing side channels.
            if not hmac.compare_digest(client.client_secret.encode(), request_client_secret.encode()):
                raise AuthenticationError("Invalid client_secret")  # pragma: no cover
            if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):
                raise AuthenticationError("Client secret has expired")  # pragma: no cover
        return client

View File

@@ -0,0 +1,301 @@
from dataclasses import dataclass
from typing import Generic, Literal, Protocol, TypeVar
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from pydantic import AnyUrl, BaseModel
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
class AuthorizationParams(BaseModel):
    # Validated parameters extracted from an /authorize request.
    state: str | None
    scopes: list[str] | None
    code_challenge: str  # PKCE code challenge (S256)
    redirect_uri: AnyUrl
    # True when the client explicitly supplied redirect_uri (vs. defaulting to a
    # single registered URI); echoed into the issued code for later validation.
    redirect_uri_provided_explicitly: bool
    resource: str | None = None  # RFC 8707 resource indicator


class AuthorizationCode(BaseModel):
    # A stored authorization code awaiting exchange at /token.
    code: str
    scopes: list[str]
    expires_at: float  # unix timestamp
    client_id: str
    code_challenge: str  # PKCE challenge the verifier must match
    redirect_uri: AnyUrl
    redirect_uri_provided_explicitly: bool
    resource: str | None = None  # RFC 8707 resource indicator


class RefreshToken(BaseModel):
    # A stored refresh token; expires_at of None means no expiry.
    token: str
    client_id: str
    scopes: list[str]
    expires_at: int | None = None


class AccessToken(BaseModel):
    # A stored/verified access token; expires_at of None means no expiry.
    token: str
    client_id: str
    scopes: list[str]
    expires_at: int | None = None
    resource: str | None = None  # RFC 8707 resource indicator


# RFC 7591 §3.2.2 dynamic client registration error codes.
RegistrationErrorCode = Literal[
    "invalid_redirect_uri",
    "invalid_client_metadata",
    "invalid_software_statement",
    "unapproved_software_statement",
]


@dataclass(frozen=True)
class RegistrationError(Exception):
    # Raised by providers to signal a client registration failure.
    error: RegistrationErrorCode
    error_description: str | None = None


# RFC 6749 §4.1.2.1 authorization endpoint error codes.
AuthorizationErrorCode = Literal[
    "invalid_request",
    "unauthorized_client",
    "access_denied",
    "unsupported_response_type",
    "invalid_scope",
    "server_error",
    "temporarily_unavailable",
]


@dataclass(frozen=True)
class AuthorizeError(Exception):
    # Raised by providers to signal an /authorize failure.
    error: AuthorizationErrorCode
    error_description: str | None = None


# RFC 6749 §5.2 token endpoint error codes.
TokenErrorCode = Literal[
    "invalid_request",
    "invalid_client",
    "invalid_grant",
    "unauthorized_client",
    "unsupported_grant_type",
    "invalid_scope",
]


@dataclass(frozen=True)
class TokenError(Exception):
    # Raised by providers to signal a /token failure.
    error: TokenErrorCode
    error_description: str | None = None


class TokenVerifier(Protocol):
    """Protocol for verifying bearer tokens."""

    async def verify_token(self, token: str) -> AccessToken | None:
        """Verify a bearer token and return access info if valid."""


# NOTE: FastMCP doesn't render any of these types in the user response, so it's
# OK to add fields to subclasses which should not be exposed externally.
AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode)
RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken)
AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken)
class OAuthAuthorizationServerProvider(Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]):
    """Structural interface an OAuth authorization-server backend must implement.

    Generic over the stored token model types so implementations may attach
    extra private fields to their AuthorizationCode/RefreshToken/AccessToken
    subclasses.
    """

    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
        """
        Retrieves client information by client ID.

        Implementors MAY raise NotImplementedError if dynamic client registration is
        disabled in ClientRegistrationOptions.

        Args:
            client_id: The ID of the client to retrieve.

        Returns:
            The client information, or None if the client does not exist.
        """

    async def register_client(self, client_info: OAuthClientInformationFull) -> None:
        """
        Saves client information as part of registering it.

        Implementors MAY raise NotImplementedError if dynamic client registration is
        disabled in ClientRegistrationOptions.

        Args:
            client_info: The client metadata to register.

        Raises:
            RegistrationError: If the client metadata is invalid.
        """

    async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
        """
        Called as part of the /authorize endpoint, and returns a URL that the client
        will be redirected to.

        Many MCP implementations will redirect to a third-party provider to perform
        a second OAuth exchange with that provider. In this sort of setup, the client
        has an OAuth connection with the MCP server, and the MCP server has an OAuth
        connection with the 3rd-party provider. At the end of this flow, the client
        should be redirected to the redirect_uri from params.redirect_uri.

            +--------+     +------------+     +-------------------+
            |        |     |            |     |                   |
            | Client | --> | MCP Server | --> |  3rd Party OAuth  |
            |        |     |            |     |      Server       |
            +--------+     +------------+     +-------------------+
                 |              ^                      |
                 |              |       Redirect       |
                 +--------------+<---------------------+
                    redirect_uri

        Implementations will need to define another handler on the MCP server return
        flow to perform the second redirect, and generate and store an authorization
        code as part of completing the OAuth authorization step.

        Implementations SHOULD generate an authorization code with at least 160 bits
        of entropy, and MUST generate an authorization code with at least 128 bits of
        entropy. See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10.

        Args:
            client: The client requesting authorization.
            params: The parameters of the authorization request.

        Returns:
            A URL to redirect the client to for authorization.

        Raises:
            AuthorizeError: If the authorization request is invalid.
        """
        ...

    async def load_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: str
    ) -> AuthorizationCodeT | None:
        """
        Loads an AuthorizationCode by its code.

        Args:
            client: The client that requested the authorization code.
            authorization_code: The authorization code to get the challenge for.

        Returns:
            The AuthorizationCode, or None if not found
        """
        ...

    async def exchange_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: AuthorizationCodeT
    ) -> OAuthToken:
        """
        Exchanges an authorization code for an access token and refresh token.

        Args:
            client: The client exchanging the authorization code.
            authorization_code: The authorization code to exchange.

        Returns:
            The OAuth token, containing access and refresh tokens.

        Raises:
            TokenError: If the request is invalid
        """
        ...

    async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshTokenT | None:
        """
        Loads a RefreshToken by its token string.

        Args:
            client: The client that is requesting to load the refresh token.
            refresh_token: The refresh token string to load.

        Returns:
            The RefreshToken object if found, or None if not found.
        """
        ...

    async def exchange_refresh_token(
        self,
        client: OAuthClientInformationFull,
        refresh_token: RefreshTokenT,
        scopes: list[str],
    ) -> OAuthToken:
        """
        Exchanges a refresh token for an access token and refresh token.

        Implementations SHOULD rotate both the access token and refresh token.

        Args:
            client: The client exchanging the refresh token.
            refresh_token: The refresh token to exchange.
            scopes: Optional scopes to request with the new access token.

        Returns:
            The OAuth token, containing access and refresh tokens.

        Raises:
            TokenError: If the request is invalid
        """
        ...

    async def load_access_token(self, token: str) -> AccessTokenT | None:
        """
        Loads an access token by its token.

        Args:
            token: The access token to verify.

        Returns:
            The AuthInfo, or None if the token is invalid.
        """

    async def revoke_token(
        self,
        token: AccessTokenT | RefreshTokenT,
    ) -> None:
        """
        Revokes an access or refresh token.

        If the given token is invalid or already revoked, this method should do nothing.

        Implementations SHOULD revoke both the access token and its corresponding
        refresh token, regardless of which of the access token or refresh token is
        provided.

        Args:
            token: the token to revoke
        """
def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str:
parsed_uri = urlparse(redirect_uri_base)
query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query).items() for v in vs]
for k, v in params.items():
if v is not None:
query_params.append((k, v))
redirect_uri = urlunparse(parsed_uri._replace(query=urlencode(query_params)))
return redirect_uri
class ProviderTokenVerifier(TokenVerifier):
    """Token verifier that uses an OAuthAuthorizationServerProvider.

    This is provided for backwards compatibility with existing auth_server_provider
    configurations. For new implementations using AS/RS separation, consider using
    the TokenVerifier protocol with a dedicated implementation like IntrospectionTokenVerifier.
    """

    def __init__(self, provider: "OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]"):
        # The provider supplies the token store; verification is a pure delegation.
        self.provider = provider

    async def verify_token(self, token: str) -> AccessToken | None:
        """Verify token using the provider's load_access_token method."""
        return await self.provider.load_access_token(token)

View File

@@ -0,0 +1,253 @@
from collections.abc import Awaitable, Callable
from typing import Any
from urllib.parse import urlparse
from pydantic import AnyHttpUrl
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route, request_response # type: ignore
from starlette.types import ASGIApp
from mcp.server.auth.handlers.authorize import AuthorizationHandler
from mcp.server.auth.handlers.metadata import MetadataHandler
from mcp.server.auth.handlers.register import RegistrationHandler
from mcp.server.auth.handlers.revoke import RevocationHandler
from mcp.server.auth.handlers.token import TokenHandler
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
from mcp.shared.auth import OAuthMetadata
def validate_issuer_url(url: AnyHttpUrl):
    """
    Validate that the issuer URL meets OAuth 2.0 requirements.

    Args:
        url: The issuer URL to validate

    Raises:
        ValueError: If the issuer URL is invalid
    """
    host = url.host
    # RFC 8414 requires HTTPS; loopback hosts (and a missing host) are exempt
    # to support local testing.
    loopback_exempt = host is None or host == "localhost" or host.startswith("127.0.0.1")
    if url.scheme != "https" and not loopback_exempt:
        raise ValueError("Issuer URL must be HTTPS")  # pragma: no cover
    # No fragments or query parameters allowed on an issuer identifier.
    if url.fragment:
        raise ValueError("Issuer URL must not have a fragment")  # pragma: no cover
    if url.query:
        raise ValueError("Issuer URL must not have a query string")  # pragma: no cover
# Well-known OAuth endpoint paths, mounted relative to the issuer URL.
AUTHORIZATION_PATH = "/authorize"
TOKEN_PATH = "/token"
REGISTRATION_PATH = "/register"
REVOCATION_PATH = "/revoke"
def cors_middleware(
    handler: Callable[[Request], Response | Awaitable[Response]],
    allow_methods: list[str],
) -> ASGIApp:
    """Wrap a request handler in permissive CORS for browser-based OAuth clients.

    Any origin may call the wrapped endpoint (e.g. MCP Inspector running in a
    browser); only the MCP protocol version header is allowed beyond the
    CORS-safelisted request headers.
    """
    return CORSMiddleware(
        app=request_response(handler),
        allow_origins="*",
        allow_methods=allow_methods,
        allow_headers=[MCP_PROTOCOL_VERSION_HEADER],
    )
def create_auth_routes(
    provider: OAuthAuthorizationServerProvider[Any, Any, Any],
    issuer_url: AnyHttpUrl,
    service_documentation_url: AnyHttpUrl | None = None,
    client_registration_options: ClientRegistrationOptions | None = None,
    revocation_options: RevocationOptions | None = None,
) -> list[Route]:
    """Build the Starlette routes for an OAuth 2.0 authorization server.

    Always includes the RFC 8414 metadata document plus the /authorize and
    /token endpoints; /register and /revoke are appended only when enabled by
    the corresponding options.

    Args:
        provider: Backing implementation for clients, codes, and tokens.
        issuer_url: Issuer URL (HTTPS required except for loopback hosts).
        service_documentation_url: Optional service docs URL advertised in metadata.
        client_registration_options: Dynamic client registration settings.
        revocation_options: Token revocation settings.

    Returns:
        Routes ready to be mounted on the server application.

    Raises:
        ValueError: If issuer_url violates RFC 8414 constraints.
    """
    validate_issuer_url(issuer_url)
    # Missing option objects default to disabled features.
    client_registration_options = client_registration_options or ClientRegistrationOptions()
    revocation_options = revocation_options or RevocationOptions()
    metadata = build_metadata(
        issuer_url,
        service_documentation_url,
        client_registration_options,
        revocation_options,
    )
    client_authenticator = ClientAuthenticator(provider)
    # Create routes
    # Allow CORS requests for endpoints meant to be hit by the OAuth client
    # (with the client secret). This is intended to support things like MCP Inspector,
    # where the client runs in a web browser.
    routes = [
        Route(
            "/.well-known/oauth-authorization-server",
            endpoint=cors_middleware(
                MetadataHandler(metadata).handle,
                ["GET", "OPTIONS"],
            ),
            methods=["GET", "OPTIONS"],
        ),
        Route(
            AUTHORIZATION_PATH,
            # do not allow CORS for authorization endpoint;
            # clients should just redirect to this
            endpoint=AuthorizationHandler(provider).handle,
            methods=["GET", "POST"],
        ),
        Route(
            TOKEN_PATH,
            endpoint=cors_middleware(
                TokenHandler(provider, client_authenticator).handle,
                ["POST", "OPTIONS"],
            ),
            methods=["POST", "OPTIONS"],
        ),
    ]
    if client_registration_options.enabled:  # pragma: no branch
        registration_handler = RegistrationHandler(
            provider,
            options=client_registration_options,
        )
        routes.append(
            Route(
                REGISTRATION_PATH,
                endpoint=cors_middleware(
                    registration_handler.handle,
                    ["POST", "OPTIONS"],
                ),
                methods=["POST", "OPTIONS"],
            )
        )
    if revocation_options.enabled:  # pragma: no branch
        revocation_handler = RevocationHandler(provider, client_authenticator)
        routes.append(
            Route(
                REVOCATION_PATH,
                endpoint=cors_middleware(
                    revocation_handler.handle,
                    ["POST", "OPTIONS"],
                ),
                methods=["POST", "OPTIONS"],
            )
        )
    return routes
def build_metadata(
    issuer_url: AnyHttpUrl,
    service_documentation_url: AnyHttpUrl | None,
    client_registration_options: ClientRegistrationOptions,
    revocation_options: RevocationOptions,
) -> OAuthMetadata:
    """Assemble RFC 8414 authorization-server metadata for the given issuer.

    Registration and revocation endpoints are advertised only when the
    corresponding options are enabled.
    """
    # Endpoint URLs are issuer-relative; rstrip avoids a double slash.
    authorization_url = AnyHttpUrl(str(issuer_url).rstrip("/") + AUTHORIZATION_PATH)
    token_url = AnyHttpUrl(str(issuer_url).rstrip("/") + TOKEN_PATH)
    # Create metadata
    metadata = OAuthMetadata(
        issuer=issuer_url,
        authorization_endpoint=authorization_url,
        token_endpoint=token_url,
        scopes_supported=client_registration_options.valid_scopes,
        response_types_supported=["code"],  # authorization-code flow only
        response_modes_supported=None,
        grant_types_supported=["authorization_code", "refresh_token"],
        token_endpoint_auth_methods_supported=["client_secret_post", "client_secret_basic"],
        token_endpoint_auth_signing_alg_values_supported=None,
        service_documentation=service_documentation_url,
        ui_locales_supported=None,
        op_policy_uri=None,
        op_tos_uri=None,
        introspection_endpoint=None,
        code_challenge_methods_supported=["S256"],  # PKCE with SHA-256 only
    )
    # Add registration endpoint if supported
    if client_registration_options.enabled:  # pragma: no branch
        metadata.registration_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REGISTRATION_PATH)
    # Add revocation endpoint if supported
    if revocation_options.enabled:  # pragma: no branch
        metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH)
        metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post", "client_secret_basic"]
    return metadata
def build_resource_metadata_url(resource_server_url: AnyHttpUrl) -> AnyHttpUrl:
    """
    Build RFC 9728 compliant protected resource metadata URL.

    Inserts /.well-known/oauth-protected-resource between host and resource path
    as specified in RFC 9728 §3.1.

    Args:
        resource_server_url: The resource server URL (e.g., https://example.com/mcp)

    Returns:
        The metadata URL (e.g., https://example.com/.well-known/oauth-protected-resource/mcp)
    """
    parts = urlparse(str(resource_server_url))
    # A bare trailing slash counts as "no resource path" for metadata purposes.
    path_suffix = "" if parts.path == "/" else parts.path
    return AnyHttpUrl(f"{parts.scheme}://{parts.netloc}/.well-known/oauth-protected-resource{path_suffix}")
def create_protected_resource_routes(
    resource_url: AnyHttpUrl,
    authorization_servers: list[AnyHttpUrl],
    scopes_supported: list[str] | None = None,
    resource_name: str | None = None,
    resource_documentation: AnyHttpUrl | None = None,
) -> list[Route]:
    """
    Create routes for OAuth 2.0 Protected Resource Metadata (RFC 9728).

    Args:
        resource_url: The URL of this resource server
        authorization_servers: List of authorization servers that can issue tokens
        scopes_supported: Optional list of scopes supported by this resource

    Returns:
        List of Starlette routes for protected resource metadata
    """
    # Imported lazily to avoid a circular import at module load time.
    from mcp.server.auth.handlers.metadata import ProtectedResourceMetadataHandler
    from mcp.shared.auth import ProtectedResourceMetadata

    handler = ProtectedResourceMetadataHandler(
        ProtectedResourceMetadata(
            resource=resource_url,
            authorization_servers=authorization_servers,
            scopes_supported=scopes_supported,
            resource_name=resource_name,
            resource_documentation=resource_documentation,
            # bearer_methods_supported defaults to ["header"] in the model
        )
    )
    # RFC 9728 §3.1: serve metadata at /.well-known/oauth-protected-resource
    # joined with the resource path; only the path portion is registered.
    well_known_path = urlparse(str(build_resource_metadata_url(resource_url))).path
    metadata_route = Route(
        well_known_path,
        endpoint=cors_middleware(handler.handle, ["GET", "OPTIONS"]),
        methods=["GET", "OPTIONS"],
    )
    return [metadata_route]

View File

@@ -0,0 +1,30 @@
from pydantic import AnyHttpUrl, BaseModel, Field
class ClientRegistrationOptions(BaseModel):
    # Whether RFC 7591 dynamic client registration (/register) is exposed.
    enabled: bool = False
    # Lifetime for issued client secrets; None means secrets never expire.
    client_secret_expiry_seconds: int | None = None
    # Scopes clients may request; also advertised as scopes_supported in metadata.
    valid_scopes: list[str] | None = None
    # Scopes granted when a registering client requests none.
    default_scopes: list[str] | None = None


class RevocationOptions(BaseModel):
    # Whether RFC 7009 token revocation (/revoke) is exposed.
    enabled: bool = False
class AuthSettings(BaseModel):
    # Settings for running an MCP server with OAuth protection.
    issuer_url: AnyHttpUrl = Field(
        ...,
        description="OAuth authorization server URL that issues tokens for this resource server.",
    )
    # Optional human-readable service documentation URL, advertised in metadata.
    service_documentation_url: AnyHttpUrl | None = None
    client_registration_options: ClientRegistrationOptions | None = None
    revocation_options: RevocationOptions | None = None
    # Scopes a bearer token must carry to access this server.
    required_scopes: list[str] | None = None

    # Resource Server settings (when operating as RS only)
    # NOTE(review): Field(...) makes this *required* even though the annotation
    # allows None — callers must pass it explicitly (possibly as None). Confirm
    # this required-but-nullable contract is intentional.
    resource_server_url: AnyHttpUrl | None = Field(
        ...,
        description="The URL of the MCP server to be used as the resource identifier "
        "and base route to look up OAuth Protected Resource Metadata.",
    )

View File

@@ -0,0 +1,190 @@
"""Elicitation utilities for MCP servers."""
from __future__ import annotations
import types
from collections.abc import Sequence
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin
from pydantic import BaseModel
from mcp.server.session import ServerSession
from mcp.types import RequestId
# Type variable bound to the caller-supplied pydantic schema for form elicitation.
ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)


class AcceptedElicitation(BaseModel, Generic[ElicitSchemaModelT]):
    """Result when user accepts the elicitation."""

    action: Literal["accept"] = "accept"
    data: ElicitSchemaModelT  # response content validated against the requested schema


class DeclinedElicitation(BaseModel):
    """Result when user declines the elicitation."""

    action: Literal["decline"] = "decline"


class CancelledElicitation(BaseModel):
    """Result when user cancels the elicitation."""

    action: Literal["cancel"] = "cancel"


# Union returned by elicit_with_validation(), discriminated by `action`.
ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation


class AcceptedUrlElicitation(BaseModel):
    """Result when user accepts a URL mode elicitation."""

    action: Literal["accept"] = "accept"


# Union returned by elicit_url(); URL mode carries no data payload.
UrlElicitationResult = AcceptedUrlElicitation | DeclinedElicitation | CancelledElicitation

# Primitive types allowed in elicitation schemas
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)
def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
    """Reject schemas whose fields go beyond elicitation-safe primitive types.

    Allowed field annotations are str/int/float/bool, string sequences such as
    list[str], and Optional/Union combinations of those. Anything else raises
    TypeError so misuse fails loudly before going over the wire.
    """
    for name, info in schema.model_fields.items():
        ann = info.annotation
        # Unannotated / pure-None fields are tolerated.
        if ann is None or ann is types.NoneType:  # pragma: no cover
            continue
        if _is_primitive_field(ann) or _is_string_sequence(ann):
            continue
        raise TypeError(
            f"Elicitation schema field '{name}' must be a primitive type "
            f"{_ELICITATION_PRIMITIVE_TYPES}, a sequence of strings (list[str], etc.), "
            f"or Optional of these types. Nested models and complex types are not allowed."
        )
def _is_string_sequence(annotation: type) -> bool:
"""Check if annotation is a sequence of strings (list[str], Sequence[str], etc)."""
origin = get_origin(annotation)
# Check if it's a sequence-like type with str elements
if origin:
try:
if issubclass(origin, Sequence):
args = get_args(annotation)
# Should have single str type arg
return len(args) == 1 and args[0] is str
except TypeError: # pragma: no cover
# origin is not a class, so it can't be a subclass of Sequence
pass
return False
def _is_primitive_field(annotation: type) -> bool:
    """Return True when *annotation* is elicitation-safe: a primitive, or a Union of primitives/None/str-sequences."""
    # Plain primitive (str, int, float, bool).
    if annotation in _ELICITATION_PRIMITIVE_TYPES:
        return True
    origin = get_origin(annotation)
    # Only typing.Union / PEP 604 unions are considered beyond plain primitives.
    if origin is not Union and origin is not types.UnionType:
        return False
    for member in get_args(annotation):
        if member is types.NoneType:
            continue
        if member in _ELICITATION_PRIMITIVE_TYPES:
            continue
        if _is_string_sequence(member):
            continue
        # Any other union member disqualifies the whole annotation.
        return False
    return True
async def elicit_with_validation(
    session: ServerSession,
    message: str,
    schema: type[ElicitSchemaModelT],
    related_request_id: RequestId | None = None,
) -> ElicitationResult[ElicitSchemaModelT]:
    """Ask the client/user for structured input and validate the reply (form mode).

    The client may show *message* to the user and collect a response matching
    *schema*, or — if the client is an agent — answer automatically. For
    sensitive data like credentials or OAuth flows, use elicit_url() instead.

    Raises:
        TypeError: If *schema* contains non-primitive fields.
        ValueError: If the client returns an unknown action.
    """
    # Fail loudly on unsupported schemas before anything goes over the wire.
    _validate_elicitation_schema(schema)
    response = await session.elicit_form(
        message=message,
        requestedSchema=schema.model_json_schema(),
        related_request_id=related_request_id,
    )
    if response.action == "accept" and response.content is not None:
        # Parse the raw content back through the caller's schema.
        return AcceptedElicitation(data=schema.model_validate(response.content))
    if response.action == "decline":
        return DeclinedElicitation()
    if response.action == "cancel":  # pragma: no cover
        return CancelledElicitation()
    # Unreachable with a spec-conforming client.
    raise ValueError(f"Unexpected elicitation action: {response.action}")  # pragma: no cover
async def elicit_url(
    session: ServerSession,
    message: str,
    url: str,
    elicitation_id: str,
    related_request_id: RequestId | None = None,
) -> UrlElicitationResult:
    """Elicit information from the user via out-of-band URL navigation (URL mode).

    Directs the user to an external URL so sensitive interactions never pass
    through the MCP client. Appropriate for credentials (API keys, passwords),
    OAuth authorization with third parties, payment flows, or anything that
    must stay out of the LLM context.

    The result only says whether the user consented to navigate; the actual
    interaction happens out-of-band, after which the server should send an
    ElicitCompleteNotification.

    Args:
        session: The server session
        message: Human-readable explanation of why the interaction is needed
        url: The URL the user should navigate to
        elicitation_id: Unique identifier for tracking this elicitation
        related_request_id: Optional ID of the request that triggered this elicitation

    Returns:
        UrlElicitationResult indicating accept, decline, or cancel
    """
    response = await session.elicit_url(
        message=message,
        url=url,
        elicitation_id=elicitation_id,
        related_request_id=related_request_id,
    )
    # Map the wire-level action onto the corresponding result model.
    outcome_by_action = {
        "accept": AcceptedUrlElicitation,
        "decline": DeclinedElicitation,
        "cancel": CancelledElicitation,
    }
    outcome = outcome_by_action.get(response.action)
    if outcome is None:  # pragma: no cover
        # Unreachable with a spec-conforming client.
        raise ValueError(f"Unexpected elicitation action: {response.action}")
    return outcome()

View File

@@ -0,0 +1,11 @@
"""
Server-side experimental features.
WARNING: These APIs are experimental and may change without notice.
Import directly from submodules:
- mcp.server.experimental.task_context.ServerTaskContext
- mcp.server.experimental.task_support.TaskSupport
- mcp.server.experimental.task_result_handler.TaskResultHandler
- mcp.server.experimental.request_context.Experimental
"""

View File

@@ -0,0 +1,238 @@
"""
Experimental request context features.
This module provides the Experimental class which gives access to experimental
features within a request context, such as task-augmented request handling.
WARNING: These APIs are experimental and may change without notice.
"""
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
from mcp.server.experimental.task_context import ServerTaskContext
from mcp.server.experimental.task_support import TaskSupport
from mcp.server.session import ServerSession
from mcp.shared.exceptions import McpError
from mcp.shared.experimental.tasks.helpers import MODEL_IMMEDIATE_RESPONSE_KEY, is_terminal
from mcp.types import (
METHOD_NOT_FOUND,
TASK_FORBIDDEN,
TASK_REQUIRED,
ClientCapabilities,
CreateTaskResult,
ErrorData,
Result,
TaskExecutionMode,
TaskMetadata,
Tool,
)
@dataclass
class Experimental:
    """
    Experimental features context for task-augmented requests.

    Provides helpers for validating task execution compatibility and
    running tasks with automatic lifecycle management.

    WARNING: This API is experimental and may change without notice.
    """

    # Task metadata from the incoming request; None means a plain (non-task) request.
    task_metadata: TaskMetadata | None = None
    # Capabilities the client declared at initialize time; hidden from repr.
    _client_capabilities: ClientCapabilities | None = field(default=None, repr=False)
    # Session used to talk back to the client while a task runs.
    _session: ServerSession | None = field(default=None, repr=False)
    # Store/queue/handler bundle; only set when tasks are enabled on the server.
    _task_support: TaskSupport | None = field(default=None, repr=False)

    @property
    def is_task(self) -> bool:
        """Check if this request is task-augmented."""
        return self.task_metadata is not None

    @property
    def client_supports_tasks(self) -> bool:
        """Check if the client declared task support."""
        if self._client_capabilities is None:
            return False
        return self._client_capabilities.tasks is not None

    def validate_task_mode(
        self,
        tool_task_mode: TaskExecutionMode | None,
        *,
        raise_error: bool = True,
    ) -> ErrorData | None:
        """
        Validate that the request is compatible with the tool's task execution mode.

        Per MCP spec:
        - "required": Clients MUST invoke as task. Server returns -32601 if not.
        - "forbidden" (or None): Clients MUST NOT invoke as task. Server returns -32601 if they do.
        - "optional": Either is acceptable.

        Args:
            tool_task_mode: The tool's execution.taskSupport value
                ("forbidden", "optional", "required", or None)
            raise_error: If True, raises McpError on validation failure. If False, returns ErrorData.

        Returns:
            None if valid, ErrorData if invalid and raise_error=False

        Raises:
            McpError: If invalid and raise_error=True
        """
        # None is treated as "forbidden" per spec.
        mode = tool_task_mode or TASK_FORBIDDEN
        error: ErrorData | None = None
        if mode == TASK_REQUIRED and not self.is_task:
            error = ErrorData(
                code=METHOD_NOT_FOUND,
                message="This tool requires task-augmented invocation",
            )
        elif mode == TASK_FORBIDDEN and self.is_task:
            error = ErrorData(
                code=METHOD_NOT_FOUND,
                message="This tool does not support task-augmented invocation",
            )
        if error is not None and raise_error:
            raise McpError(error)
        return error

    def validate_for_tool(
        self,
        tool: Tool,
        *,
        raise_error: bool = True,
    ) -> ErrorData | None:
        """
        Validate that the request is compatible with the given tool.

        Convenience wrapper around validate_task_mode that extracts the mode from a Tool.

        Args:
            tool: The Tool definition
            raise_error: If True, raises McpError on validation failure.

        Returns:
            None if valid, ErrorData if invalid and raise_error=False
        """
        mode = tool.execution.taskSupport if tool.execution else None
        return self.validate_task_mode(mode, raise_error=raise_error)

    def can_use_tool(self, tool_task_mode: TaskExecutionMode | None) -> bool:
        """
        Check if this client can use a tool with the given task mode.

        Useful for filtering tool lists or providing warnings.
        Returns False if tool requires "required" but client doesn't support tasks.

        Args:
            tool_task_mode: The tool's execution.taskSupport value

        Returns:
            True if the client can use this tool, False otherwise
        """
        mode = tool_task_mode or TASK_FORBIDDEN
        if mode == TASK_REQUIRED and not self.client_supports_tasks:
            return False
        return True

    async def run_task(
        self,
        work: Callable[[ServerTaskContext], Awaitable[Result]],
        *,
        task_id: str | None = None,
        model_immediate_response: str | None = None,
    ) -> CreateTaskResult:
        """
        Create a task, spawn background work, and return CreateTaskResult immediately.

        This is the recommended way to handle task-augmented tool calls. It:
        1. Creates a task in the store
        2. Spawns the work function in a background task
        3. Returns CreateTaskResult immediately

        The work function receives a ServerTaskContext with elicit(),
        create_message(), update_status(), and complete()/fail(). When work()
        returns a Result, the task is auto-completed with that result; if it
        raises, the task is auto-failed.

        Args:
            work: Async function that does the actual work
            task_id: Optional task ID (generated if not provided)
            model_immediate_response: Optional string to include in _meta as
                io.modelcontextprotocol/model-immediate-response

        Returns:
            CreateTaskResult to return to the client

        Raises:
            RuntimeError: If task support is not enabled or task_metadata is missing

        WARNING: This API is experimental and may change without notice.
        """
        if self._task_support is None:
            raise RuntimeError("Task support not enabled. Call server.experimental.enable_tasks() first.")
        if self._session is None:
            raise RuntimeError("Session not available.")
        if self.task_metadata is None:
            raise RuntimeError(
                "Request is not task-augmented (no task field in params). "
                "The client must send a task-augmented request."
            )
        support = self._task_support
        # Access task_group via TaskSupport - raises if not in run() context
        task_group = support.task_group
        task = await support.store.create_task(self.task_metadata, task_id)
        task_ctx = ServerTaskContext(
            task=task,
            store=support.store,
            session=self._session,
            queue=support.queue,
            handler=support.handler,
        )

        async def execute() -> None:
            # Auto-complete/fail only if work() did not already move the task
            # to a terminal state itself.
            try:
                result = await work(task_ctx)
                if not is_terminal(task_ctx.task.status):
                    await task_ctx.complete(result)
            except Exception as e:
                if not is_terminal(task_ctx.task.status):
                    await task_ctx.fail(str(e))

        # Fire-and-forget: the response below is returned before work finishes.
        task_group.start_soon(execute)
        meta: dict[str, Any] | None = None
        if model_immediate_response is not None:
            meta = {MODEL_IMMEDIATE_RESPONSE_KEY: model_immediate_response}
        return CreateTaskResult(task=task, **{"_meta": meta} if meta else {})

View File

@@ -0,0 +1,220 @@
"""
Experimental server session features for server→client task operations.
This module provides the server-side equivalent of ExperimentalClientFeatures,
allowing the server to send task-augmented requests to the client and poll for results.
WARNING: These APIs are experimental and may change without notice.
"""
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, TypeVar
import mcp.types as types
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
from mcp.shared.experimental.tasks.capabilities import (
require_task_augmented_elicitation,
require_task_augmented_sampling,
)
from mcp.shared.experimental.tasks.polling import poll_until_terminal
if TYPE_CHECKING:
from mcp.server.session import ServerSession
ResultT = TypeVar("ResultT", bound=types.Result)
class ExperimentalServerSessionFeatures:
    """
    Experimental server session features for server→client task operations.
    This provides the server-side equivalent of ExperimentalClientFeatures,
    allowing the server to send task-augmented requests to the client and
    poll for results.
    WARNING: These APIs are experimental and may change without notice.
    Access via session.experimental:
        result = await session.experimental.elicit_as_task(...)
    """
    def __init__(self, session: "ServerSession") -> None:
        # Only a session reference is held; all traffic flows through it.
        self._session = session
    async def get_task(self, task_id: str) -> types.GetTaskResult:
        """
        Send tasks/get to the client to get task status.
        Args:
            task_id: The task identifier
        Returns:
            GetTaskResult containing the task status
        """
        request = types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))
        return await self._session.send_request(types.ServerRequest(request), types.GetTaskResult)
    async def get_task_result(
        self,
        task_id: str,
        result_type: type[ResultT],
    ) -> ResultT:
        """
        Send tasks/result to the client to retrieve the final result.
        Args:
            task_id: The task identifier
            result_type: The expected result type
        Returns:
            The task result, validated against result_type
        """
        request = types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(taskId=task_id))
        return await self._session.send_request(types.ServerRequest(request), result_type)
    async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
        """
        Poll a client task until it reaches terminal status.
        Yields GetTaskResult for each poll, allowing the caller to react to
        status changes. Exits when task reaches a terminal status.
        Respects the pollInterval hint from the client.
        Args:
            task_id: The task identifier
        Yields:
            GetTaskResult for each poll
        """
        async for status in poll_until_terminal(self.get_task, task_id):
            yield status
    async def _send_task_augmented(
        self,
        request: types.ServerRequest,
        result_type: type[ResultT],
    ) -> ResultT:
        """
        Shared flow for task-augmented requests.
        Sends *request*, expects a CreateTaskResult back, polls the client's
        task until terminal, then fetches the final result validated as
        *result_type*.
        """
        create_result = await self._session.send_request(request, types.CreateTaskResult)
        task_id = create_result.task.taskId
        # Drain the poll iterator; only the terminal transition matters here.
        async for _ in self.poll_task(task_id):
            pass
        return await self.get_task_result(task_id, result_type)
    async def elicit_as_task(
        self,
        message: str,
        requestedSchema: types.ElicitRequestedSchema,
        *,
        ttl: int = 60000,
    ) -> types.ElicitResult:
        """
        Send a task-augmented elicitation to the client and poll until complete.
        The client will create a local task, process the elicitation
        asynchronously, and return the result when ready. This handles the
        full flow: send with task field, receive CreateTaskResult, poll the
        client's task until terminal, retrieve the final ElicitResult.
        Args:
            message: The message to present to the user
            requestedSchema: Schema defining the expected response
            ttl: Task time-to-live in milliseconds
        Returns:
            The client's elicitation response
        Raises:
            McpError: If client doesn't support task-augmented elicitation
        """
        client_caps = self._session.client_params.capabilities if self._session.client_params else None
        require_task_augmented_elicitation(client_caps)
        request = types.ServerRequest(
            types.ElicitRequest(
                params=types.ElicitRequestFormParams(
                    message=message,
                    requestedSchema=requestedSchema,
                    task=types.TaskMetadata(ttl=ttl),
                )
            )
        )
        return await self._send_task_augmented(request, types.ElicitResult)
    async def create_message_as_task(
        self,
        messages: list[types.SamplingMessage],
        *,
        max_tokens: int,
        ttl: int = 60000,
        system_prompt: str | None = None,
        include_context: types.IncludeContext | None = None,
        temperature: float | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        model_preferences: types.ModelPreferences | None = None,
        tools: list[types.Tool] | None = None,
        tool_choice: types.ToolChoice | None = None,
    ) -> types.CreateMessageResult:
        """
        Send a task-augmented sampling request and poll until complete.
        The client will create a local task, process the sampling request
        asynchronously, and return the result when ready.
        Args:
            messages: The conversation messages for sampling
            max_tokens: Maximum tokens in the response
            ttl: Task time-to-live in milliseconds
            system_prompt: Optional system prompt
            include_context: Context inclusion strategy
            temperature: Sampling temperature
            stop_sequences: Stop sequences
            metadata: Additional metadata
            model_preferences: Model selection preferences
            tools: Optional list of tools the LLM can use during sampling
            tool_choice: Optional control over tool usage behavior
        Returns:
            The sampling result from the client
        Raises:
            McpError: If client doesn't support task-augmented sampling or tools
            ValueError: If tool_use or tool_result message structure is invalid
        """
        client_caps = self._session.client_params.capabilities if self._session.client_params else None
        require_task_augmented_sampling(client_caps)
        validate_sampling_tools(client_caps, tools, tool_choice)
        validate_tool_use_result_messages(messages)
        request = types.ServerRequest(
            types.CreateMessageRequest(
                params=types.CreateMessageRequestParams(
                    messages=messages,
                    maxTokens=max_tokens,
                    systemPrompt=system_prompt,
                    includeContext=include_context,
                    temperature=temperature,
                    stopSequences=stop_sequences,
                    metadata=metadata,
                    modelPreferences=model_preferences,
                    tools=tools,
                    toolChoice=tool_choice,
                    task=types.TaskMetadata(ttl=ttl),
                )
            )
        )
        return await self._send_task_augmented(request, types.CreateMessageResult)

View File

@@ -0,0 +1,612 @@
"""
ServerTaskContext - Server-integrated task context with elicitation and sampling.
This wraps the pure TaskContext and adds server-specific functionality:
- Elicitation (task.elicit())
- Sampling (task.create_message())
- Status notifications
"""
from typing import Any
import anyio
from mcp.server.experimental.task_result_handler import TaskResultHandler
from mcp.server.session import ServerSession
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
from mcp.shared.exceptions import McpError
from mcp.shared.experimental.tasks.capabilities import (
require_task_augmented_elicitation,
require_task_augmented_sampling,
)
from mcp.shared.experimental.tasks.context import TaskContext
from mcp.shared.experimental.tasks.message_queue import QueuedMessage, TaskMessageQueue
from mcp.shared.experimental.tasks.resolver import Resolver
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.types import (
INVALID_REQUEST,
TASK_STATUS_INPUT_REQUIRED,
TASK_STATUS_WORKING,
ClientCapabilities,
CreateMessageResult,
CreateTaskResult,
ElicitationCapability,
ElicitRequestedSchema,
ElicitResult,
ErrorData,
IncludeContext,
ModelPreferences,
RequestId,
Result,
SamplingCapability,
SamplingMessage,
ServerNotification,
Task,
TaskMetadata,
TaskStatusNotification,
TaskStatusNotificationParams,
Tool,
ToolChoice,
)
class ServerTaskContext:
    """
    Server-integrated task context with elicitation and sampling.
    This wraps a pure TaskContext and adds server-specific functionality:
    - elicit() / elicit_url() for sending elicitation requests to the client
    - create_message() for sampling requests
    - elicit_as_task() / create_message_as_task() for delegating to a task
      on the client side
    - Status notifications via the session
    All client-bound requests flow through the task message queue and are
    delivered when the client calls tasks/result (see TaskResultHandler).
    Example:
        async def my_task_work(task: ServerTaskContext) -> CallToolResult:
            await task.update_status("Starting...")
            result = await task.elicit(
                message="Continue?",
                requestedSchema={"type": "object", "properties": {"ok": {"type": "boolean"}}}
            )
            if result.content.get("ok"):
                return CallToolResult(content=[TextContent(text="Done!")])
            else:
                return CallToolResult(content=[TextContent(text="Cancelled")])
    """
    def __init__(
        self,
        *,
        task: Task,
        store: TaskStore,
        session: ServerSession,
        queue: TaskMessageQueue,
        handler: TaskResultHandler | None = None,
    ):
        """
        Create a ServerTaskContext.
        Args:
            task: The Task object
            store: The task store
            session: The server session
            queue: The message queue for elicitation/sampling
            handler: The result handler for response routing (required for
                elicit/elicit_url/create_message and the *_as_task variants)
        """
        self._ctx = TaskContext(task=task, store=store)
        self._session = session
        self._queue = queue
        self._handler = handler
        self._store = store
    # Delegate pure properties to inner context
    @property
    def task_id(self) -> str:
        """The task identifier."""
        return self._ctx.task_id
    @property
    def task(self) -> Task:
        """The current task state."""
        return self._ctx.task
    @property
    def is_cancelled(self) -> bool:
        """Whether cancellation has been requested."""
        return self._ctx.is_cancelled
    def request_cancellation(self) -> None:
        """Request cancellation of this task."""
        self._ctx.request_cancellation()
    # Enhanced methods with notifications
    async def update_status(self, message: str, *, notify: bool = True) -> None:
        """
        Update the task's status message.
        Args:
            message: The new status message
            notify: Whether to send a notification to the client
        """
        await self._ctx.update_status(message)
        if notify:
            await self._send_notification()
    async def complete(self, result: Result, *, notify: bool = True) -> None:
        """
        Mark the task as completed with the given result.
        Args:
            result: The task result
            notify: Whether to send a notification to the client
        """
        await self._ctx.complete(result)
        if notify:
            await self._send_notification()
    async def fail(self, error: str, *, notify: bool = True) -> None:
        """
        Mark the task as failed with an error message.
        Args:
            error: The error message
            notify: Whether to send a notification to the client
        """
        await self._ctx.fail(error)
        if notify:
            await self._send_notification()
    async def _send_notification(self) -> None:
        """Send a task status notification to the client."""
        task = self._ctx.task
        await self._session.send_notification(
            ServerNotification(
                TaskStatusNotification(
                    params=TaskStatusNotificationParams(
                        taskId=task.taskId,
                        status=task.status,
                        statusMessage=task.statusMessage,
                        createdAt=task.createdAt,
                        lastUpdatedAt=task.lastUpdatedAt,
                        ttl=task.ttl,
                        pollInterval=task.pollInterval,
                    )
                )
            )
        )
    # Capability and configuration guards
    def _check_elicitation_capability(self) -> None:
        """Check if the client supports elicitation."""
        if not self._session.check_client_capability(ClientCapabilities(elicitation=ElicitationCapability())):
            raise McpError(
                ErrorData(
                    code=INVALID_REQUEST,
                    message="Client does not support elicitation capability",
                )
            )
    def _check_sampling_capability(self) -> None:
        """Check if the client supports sampling."""
        if not self._session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())):
            raise McpError(
                ErrorData(
                    code=INVALID_REQUEST,
                    message="Client does not support sampling capability",
                )
            )
    def _require_handler(self, api: str) -> TaskResultHandler:
        """Return the configured handler, or raise naming the calling API."""
        if self._handler is None:
            raise RuntimeError(f"handler is required for {api}(). Pass handler= to ServerTaskContext.")
        return self._handler
    # Shared queue plumbing for all client-bound requests
    async def _restore_working(self) -> None:
        """Flip the task status back to "working" after an input exchange."""
        await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
    async def _enqueue_request(self, request: Any) -> Resolver[dict[str, Any]]:
        """
        Mark the task "input_required", register a resolver for the request's
        response, and enqueue the request for delivery via tasks/result.
        Returns the resolver the caller should wait on.
        """
        handler = self._handler
        assert handler is not None  # callers check via _require_handler first
        await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
        request_id: RequestId = request.id
        resolver: Resolver[dict[str, Any]] = Resolver()
        handler._pending_requests[request_id] = resolver  # pyright: ignore[reportPrivateUsage]
        await self._queue.enqueue(
            self.task_id,
            QueuedMessage(
                type="request",
                message=request,
                resolver=resolver,
                original_request_id=request_id,
            ),
        )
        return resolver
    async def _exchange(self, request: Any) -> dict[str, Any]:
        """
        Queue *request*, wait for the routed response, and restore "working".
        On cancellation the status is restored before re-raising; any other
        exception propagates with the task left in "input_required".
        """
        resolver = await self._enqueue_request(request)
        try:
            # Wait for response (routed back via TaskResultHandler)
            response_data = await resolver.wait()
        except anyio.get_cancelled_exc_class():
            await self._restore_working()
            raise
        await self._restore_working()
        return response_data
    async def _delegate_to_client_task(self, request: Any, result_type: Any) -> Any:
        """
        Queue a task-augmented *request*; once the client answers with a
        CreateTaskResult, poll the client's task to terminal status and return
        the final result validated as *result_type*. Restores "working" on
        success and on cancellation.
        """
        resolver = await self._enqueue_request(request)
        try:
            # Initial response is the client's CreateTaskResult.
            create_result = CreateTaskResult.model_validate(await resolver.wait())
            client_task_id = create_result.task.taskId
            # Poll the client's task using session.experimental
            async for _ in self._session.experimental.poll_task(client_task_id):
                pass
            result = await self._session.experimental.get_task_result(client_task_id, result_type)
        except anyio.get_cancelled_exc_class():
            await self._restore_working()
            raise
        await self._restore_working()
        return result
    # Public request APIs
    async def elicit(
        self,
        message: str,
        requestedSchema: ElicitRequestedSchema,
    ) -> ElicitResult:
        """
        Send an elicitation request via the task message queue.
        The task moves to "input_required" while waiting and back to
        "working" once the client responds (or the wait is cancelled).
        Args:
            message: The message to present to the user
            requestedSchema: Schema defining the expected response structure
        Returns:
            The client's response
        Raises:
            McpError: If client doesn't support elicitation capability
            RuntimeError: If handler is not configured
        """
        self._check_elicitation_capability()
        self._require_handler("elicit")
        request = self._session._build_elicit_form_request(  # pyright: ignore[reportPrivateUsage]
            message=message,
            requestedSchema=requestedSchema,
            related_task_id=self.task_id,
        )
        return ElicitResult.model_validate(await self._exchange(request))
    async def elicit_url(
        self,
        message: str,
        url: str,
        elicitation_id: str,
    ) -> ElicitResult:
        """
        Send a URL mode elicitation request via the task message queue.
        This directs the user to an external URL for out-of-band interactions
        like OAuth flows, credential collection, or payment processing. The
        task moves to "input_required" while waiting and back to "working"
        once the client responds (or the wait is cancelled).
        Args:
            message: Human-readable explanation of why the interaction is needed
            url: The URL the user should navigate to
            elicitation_id: Unique identifier for tracking this elicitation
        Returns:
            The client's response indicating acceptance, decline, or cancellation
        Raises:
            McpError: If client doesn't support elicitation capability
            RuntimeError: If handler is not configured
        """
        self._check_elicitation_capability()
        self._require_handler("elicit_url")
        request = self._session._build_elicit_url_request(  # pyright: ignore[reportPrivateUsage]
            message=message,
            url=url,
            elicitation_id=elicitation_id,
            related_task_id=self.task_id,
        )
        return ElicitResult.model_validate(await self._exchange(request))
    async def create_message(
        self,
        messages: list[SamplingMessage],
        *,
        max_tokens: int,
        system_prompt: str | None = None,
        include_context: IncludeContext | None = None,
        temperature: float | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        model_preferences: ModelPreferences | None = None,
        tools: list[Tool] | None = None,
        tool_choice: ToolChoice | None = None,
    ) -> CreateMessageResult:
        """
        Send a sampling request via the task message queue.
        The task moves to "input_required" while waiting and back to
        "working" once the client responds (or the wait is cancelled).
        Args:
            messages: The conversation messages for sampling
            max_tokens: Maximum tokens in the response
            system_prompt: Optional system prompt
            include_context: Context inclusion strategy
            temperature: Sampling temperature
            stop_sequences: Stop sequences
            metadata: Additional metadata
            model_preferences: Model selection preferences
            tools: Optional list of tools the LLM can use during sampling
            tool_choice: Optional control over tool usage behavior
        Returns:
            The sampling result from the client
        Raises:
            McpError: If client doesn't support sampling capability or tools
            ValueError: If tool_use or tool_result message structure is invalid
            RuntimeError: If handler is not configured
        """
        self._check_sampling_capability()
        client_caps = self._session.client_params.capabilities if self._session.client_params else None
        validate_sampling_tools(client_caps, tools, tool_choice)
        validate_tool_use_result_messages(messages)
        self._require_handler("create_message")
        request = self._session._build_create_message_request(  # pyright: ignore[reportPrivateUsage]
            messages=messages,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            include_context=include_context,
            temperature=temperature,
            stop_sequences=stop_sequences,
            metadata=metadata,
            model_preferences=model_preferences,
            tools=tools,
            tool_choice=tool_choice,
            related_task_id=self.task_id,
        )
        return CreateMessageResult.model_validate(await self._exchange(request))
    async def elicit_as_task(
        self,
        message: str,
        requestedSchema: ElicitRequestedSchema,
        *,
        ttl: int = 60000,
    ) -> ElicitResult:
        """
        Send a task-augmented elicitation via the queue, then poll client.
        This is for use inside a task-augmented tool call when you want the
        client to handle the elicitation as its own task. The elicitation
        request is queued and delivered when the client calls tasks/result.
        After the client responds with CreateTaskResult, we poll the client's
        task until complete.
        Args:
            message: The message to present to the user
            requestedSchema: Schema defining the expected response structure
            ttl: Task time-to-live in milliseconds for the client's task
        Returns:
            The client's elicitation response
        Raises:
            McpError: If client doesn't support task-augmented elicitation
            RuntimeError: If handler is not configured
        """
        client_caps = self._session.client_params.capabilities if self._session.client_params else None
        require_task_augmented_elicitation(client_caps)
        self._require_handler("elicit_as_task")
        request = self._session._build_elicit_form_request(  # pyright: ignore[reportPrivateUsage]
            message=message,
            requestedSchema=requestedSchema,
            related_task_id=self.task_id,
            task=TaskMetadata(ttl=ttl),
        )
        return await self._delegate_to_client_task(request, ElicitResult)
    async def create_message_as_task(
        self,
        messages: list[SamplingMessage],
        *,
        max_tokens: int,
        ttl: int = 60000,
        system_prompt: str | None = None,
        include_context: IncludeContext | None = None,
        temperature: float | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        model_preferences: ModelPreferences | None = None,
        tools: list[Tool] | None = None,
        tool_choice: ToolChoice | None = None,
    ) -> CreateMessageResult:
        """
        Send a task-augmented sampling request via the queue, then poll client.
        This is for use inside a task-augmented tool call when you want the
        client to handle the sampling as its own task. The request is queued
        and delivered when the client calls tasks/result. After the client
        responds with CreateTaskResult, we poll the client's task until
        complete.
        Args:
            messages: The conversation messages for sampling
            max_tokens: Maximum tokens in the response
            ttl: Task time-to-live in milliseconds for the client's task
            system_prompt: Optional system prompt
            include_context: Context inclusion strategy
            temperature: Sampling temperature
            stop_sequences: Stop sequences
            metadata: Additional metadata
            model_preferences: Model selection preferences
            tools: Optional list of tools the LLM can use during sampling
            tool_choice: Optional control over tool usage behavior
        Returns:
            The sampling result from the client
        Raises:
            McpError: If client doesn't support task-augmented sampling or tools
            ValueError: If tool_use or tool_result message structure is invalid
            RuntimeError: If handler is not configured
        """
        client_caps = self._session.client_params.capabilities if self._session.client_params else None
        require_task_augmented_sampling(client_caps)
        validate_sampling_tools(client_caps, tools, tool_choice)
        validate_tool_use_result_messages(messages)
        self._require_handler("create_message_as_task")
        # Build request WITH task field for task-augmented sampling
        request = self._session._build_create_message_request(  # pyright: ignore[reportPrivateUsage]
            messages=messages,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            include_context=include_context,
            temperature=temperature,
            stop_sequences=stop_sequences,
            metadata=metadata,
            model_preferences=model_preferences,
            tools=tools,
            tool_choice=tool_choice,
            related_task_id=self.task_id,
            task=TaskMetadata(ttl=ttl),
        )
        return await self._delegate_to_client_task(request, CreateMessageResult)

View File

@@ -0,0 +1,235 @@
"""
TaskResultHandler - Integrated handler for tasks/result endpoint.
This implements the dequeue-send-wait pattern from the MCP Tasks spec:
1. Dequeue all pending messages for the task
2. Send them to the client via transport with relatedRequestId routing
3. Wait if task is not in terminal state
4. Return final result when task completes
This is the core of the task message queue pattern.
"""
import logging
from typing import Any
import anyio
from mcp.server.session import ServerSession
from mcp.shared.exceptions import McpError
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal
from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue
from mcp.shared.experimental.tasks.resolver import Resolver
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.types import (
INVALID_PARAMS,
ErrorData,
GetTaskPayloadRequest,
GetTaskPayloadResult,
JSONRPCMessage,
RelatedTaskMetadata,
RequestId,
)
logger = logging.getLogger(__name__)
class TaskResultHandler:
"""
Handler for tasks/result that implements the message queue pattern.
This handler:
1. Dequeues pending messages (elicitations, notifications) for the task
2. Sends them to the client via the response stream
3. Waits for responses and resolves them back to callers
4. Blocks until task reaches terminal state
5. Returns the final result
Usage:
# Create handler with store and queue
handler = TaskResultHandler(task_store, message_queue)
# Register it with the server
@server.experimental.get_task_result()
async def handle_task_result(req: GetTaskPayloadRequest) -> GetTaskPayloadResult:
ctx = server.request_context
return await handler.handle(req, ctx.session, ctx.request_id)
# Or use the convenience method
handler.register(server)
"""
def __init__(
self,
store: TaskStore,
queue: TaskMessageQueue,
):
self._store = store
self._queue = queue
# Map from internal request ID to resolver for routing responses
self._pending_requests: dict[RequestId, Resolver[dict[str, Any]]] = {}
async def send_message(
self,
session: ServerSession,
message: SessionMessage,
) -> None:
"""
Send a message via the session.
This is a helper for delivering queued task messages.
"""
await session.send_message(message)
async def handle(
self,
request: GetTaskPayloadRequest,
session: ServerSession,
request_id: RequestId,
) -> GetTaskPayloadResult:
"""
Handle a tasks/result request.
This implements the dequeue-send-wait loop:
1. Dequeue all pending messages
2. Send each via transport with relatedRequestId = this request's ID
3. If task not terminal, wait for status change
4. Loop until task is terminal
5. Return final result
Args:
request: The GetTaskPayloadRequest
session: The server session for sending messages
request_id: The request ID for relatedRequestId routing
Returns:
GetTaskPayloadResult with the task's final payload
"""
task_id = request.params.taskId
while True:
task = await self._store.get_task(task_id)
if task is None:
raise McpError(
ErrorData(
code=INVALID_PARAMS,
message=f"Task not found: {task_id}",
)
)
await self._deliver_queued_messages(task_id, session, request_id)
# If task is terminal, return result
if is_terminal(task.status):
result = await self._store.get_result(task_id)
# GetTaskPayloadResult is a Result with extra="allow"
# The stored result contains the actual payload data
# Per spec: tasks/result MUST include _meta with related-task metadata
related_task = RelatedTaskMetadata(taskId=task_id)
related_task_meta: dict[str, Any] = {RELATED_TASK_METADATA_KEY: related_task.model_dump(by_alias=True)}
if result is not None:
result_data = result.model_dump(by_alias=True)
existing_meta: dict[str, Any] = result_data.get("_meta") or {}
result_data["_meta"] = {**existing_meta, **related_task_meta}
return GetTaskPayloadResult.model_validate(result_data)
return GetTaskPayloadResult.model_validate({"_meta": related_task_meta})
# Wait for task update (status change or new messages)
await self._wait_for_task_update(task_id)
async def _deliver_queued_messages(
self,
task_id: str,
session: ServerSession,
request_id: RequestId,
) -> None:
"""
Dequeue and send all pending messages for a task.
Each message is sent via the session's write stream with
relatedRequestId set so responses route back to this stream.
"""
while True:
message = await self._queue.dequeue(task_id)
if message is None:
break
# If this is a request (not notification), wait for response
if message.type == "request" and message.resolver is not None:
# Store the resolver so we can route the response back
original_id = message.original_request_id
if original_id is not None:
self._pending_requests[original_id] = message.resolver
logger.debug("Delivering queued message for task %s: %s", task_id, message.type)
# Send the message with relatedRequestId for routing
session_message = SessionMessage(
message=JSONRPCMessage(message.message),
metadata=ServerMessageMetadata(related_request_id=request_id),
)
await self.send_message(session, session_message)
async def _wait_for_task_update(self, task_id: str) -> None:
"""
Wait for task to be updated (status change or new message).
Races between store update and queue message - first one wins.
"""
async with anyio.create_task_group() as tg:
async def wait_for_store() -> None:
try:
await self._store.wait_for_update(task_id)
except Exception:
pass
finally:
tg.cancel_scope.cancel()
async def wait_for_queue() -> None:
try:
await self._queue.wait_for_message(task_id)
except Exception:
pass
finally:
tg.cancel_scope.cancel()
tg.start_soon(wait_for_store)
tg.start_soon(wait_for_queue)
def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool:
"""
Route a response back to the waiting resolver.
This is called when a response arrives for a queued request.
Args:
request_id: The request ID from the response
response: The response data
Returns:
True if response was routed, False if no pending request
"""
resolver = self._pending_requests.pop(request_id, None)
if resolver is not None and not resolver.done():
resolver.set_result(response)
return True
return False
def route_error(self, request_id: RequestId, error: ErrorData) -> bool:
    """
    Route an error back to the waiting resolver.

    Args:
        request_id: The request ID from the error response
        error: The error data

    Returns:
        True if error was routed, False if no pending request
    """
    waiter = self._pending_requests.pop(request_id, None)
    if waiter is None or waiter.done():
        # No live waiter for this request id - nothing to deliver to.
        return False
    waiter.set_exception(McpError(error))
    return True

View File

@@ -0,0 +1,115 @@
"""
TaskSupport - Configuration for experimental task support.
This module provides the TaskSupport class which encapsulates all the
infrastructure needed for task-augmented requests: store, queue, and handler.
"""
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
import anyio
from anyio.abc import TaskGroup
from mcp.server.experimental.task_result_handler import TaskResultHandler
from mcp.server.session import ServerSession
from mcp.shared.experimental.tasks.in_memory_task_store import InMemoryTaskStore
from mcp.shared.experimental.tasks.message_queue import InMemoryTaskMessageQueue, TaskMessageQueue
from mcp.shared.experimental.tasks.store import TaskStore
@dataclass
class TaskSupport:
    """
    Configuration for experimental task support.

    Encapsulates the task store, message queue, result handler, and task group
    for spawning background work.

    When enabled on a server, this automatically:
    - Configures response routing for each session
    - Provides default handlers for task operations
    - Manages a task group for background task execution

    Example:
        # Simple in-memory setup
        server.experimental.enable_tasks()

        # Custom store/queue for distributed systems
        server.experimental.enable_tasks(
            store=RedisTaskStore(redis_url),
            queue=RedisTaskMessageQueue(redis_url),
        )
    """

    # Persistent record of task state; pluggable for distributed setups.
    store: TaskStore
    # Transport for messages exchanged with in-flight tasks.
    queue: TaskMessageQueue
    # Derived from store + queue in __post_init__; not a constructor argument.
    handler: TaskResultHandler = field(init=False)
    # Set only while run() is active; accessed via the task_group property.
    _task_group: TaskGroup | None = field(init=False, default=None)

    def __post_init__(self) -> None:
        """Create the result handler from store and queue."""
        self.handler = TaskResultHandler(self.store, self.queue)

    @property
    def task_group(self) -> TaskGroup:
        """Get the task group for spawning background work.

        Raises:
            RuntimeError: If not within a run() context
        """
        if self._task_group is None:
            raise RuntimeError("TaskSupport not running. Ensure Server.run() is active.")
        return self._task_group

    @asynccontextmanager
    async def run(self) -> AsyncIterator[None]:
        """
        Run the task support lifecycle.

        This creates a task group for spawning background task work.
        Called automatically by Server.run().

        Usage:
            async with task_support.run():
                # Task group is now available
                ...
        """
        async with anyio.create_task_group() as tg:
            self._task_group = tg
            try:
                yield
            finally:
                # Clear the reference even on error/cancellation so that
                # task_group fails fast rather than handing out a dead group.
                self._task_group = None

    def configure_session(self, session: ServerSession) -> None:
        """
        Configure a session for task support.

        This registers the result handler as a response router so that
        responses to queued requests (elicitation, sampling) are routed
        back to the waiting resolvers.

        Called automatically by Server.run() for each new session.

        Args:
            session: The session to configure
        """
        session.add_response_router(self.handler)

    @classmethod
    def in_memory(cls) -> "TaskSupport":
        """
        Create in-memory task support.

        Suitable for development, testing, and single-process servers.
        For distributed systems, provide custom store and queue implementations.

        Returns:
            TaskSupport configured with in-memory store and queue
        """
        return cls(
            store=InMemoryTaskStore(),
            queue=InMemoryTaskMessageQueue(),
        )

View File

@@ -0,0 +1,11 @@
"""FastMCP - A more ergonomic interface for MCP servers."""
from importlib.metadata import version
from mcp.types import Icon
from .server import Context, FastMCP
from .utilities.types import Audio, Image
__version__ = version("mcp")
__all__ = ["FastMCP", "Context", "Image", "Audio", "Icon"]

View File

@@ -0,0 +1,21 @@
"""Custom exceptions for FastMCP."""
class FastMCPError(Exception):
    """Base error for FastMCP."""

    # Root of the FastMCP error hierarchy; catching this also catches
    # ValidationError, ResourceError, and ToolError.
class ValidationError(FastMCPError):
    """Error in validating parameters or return values."""
class ResourceError(FastMCPError):
    """Error in resource operations."""
class ToolError(FastMCPError):
    """Error in tool operations."""
class InvalidSignature(Exception):
    """Invalid signature for use with FastMCP."""

    # NOTE(review): inherits Exception rather than FastMCPError, so
    # `except FastMCPError` will NOT catch it - confirm this is intended.

View File

@@ -0,0 +1,4 @@
from .base import Prompt
from .manager import PromptManager
__all__ = ["Prompt", "PromptManager"]

View File

@@ -0,0 +1,183 @@
"""Base classes for FastMCP prompts."""
from __future__ import annotations
import inspect
from collections.abc import Awaitable, Callable, Sequence
from typing import TYPE_CHECKING, Any, Literal
import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
from mcp.types import ContentBlock, Icon, TextContent
if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT, RequestT
class Message(BaseModel):
    """Base class for all prompt messages."""

    role: Literal["user", "assistant"]
    content: ContentBlock

    def __init__(self, content: str | ContentBlock, **kwargs: Any):
        # Bare strings are promoted to TextContent for caller convenience.
        normalized = TextContent(type="text", text=content) if isinstance(content, str) else content
        super().__init__(content=normalized, **kwargs)
class UserMessage(Message):
    """A message from the user."""

    # Defaults the role to "user"; the annotation keeps the full literal
    # union so an explicitly supplied role still validates.
    role: Literal["user", "assistant"] = "user"

    def __init__(self, content: str | ContentBlock, **kwargs: Any):
        super().__init__(content=content, **kwargs)
class AssistantMessage(Message):
    """A message from the assistant."""

    # Defaults the role to "assistant"; the annotation keeps the full
    # literal union so an explicitly supplied role still validates.
    role: Literal["user", "assistant"] = "assistant"

    def __init__(self, content: str | ContentBlock, **kwargs: Any):
        super().__init__(content=content, **kwargs)
# Validates raw dicts into concrete message types during Prompt.render().
message_validator = TypeAdapter[UserMessage | AssistantMessage](UserMessage | AssistantMessage)

# Everything a prompt function may return before normalization to Message
# objects: a single item or a sequence, sync or awaitable.
SyncPromptResult = str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]
class PromptArgument(BaseModel):
    """An argument that can be passed to a prompt."""

    name: str = Field(description="Name of the argument")
    description: str | None = Field(None, description="Description of what the argument does")
    # Populated from the function's JSON schema by Prompt.from_function.
    required: bool = Field(default=False, description="Whether the argument is required")
class Prompt(BaseModel):
    """A prompt template that can be rendered with parameters."""

    name: str = Field(description="Name of the prompt")
    title: str | None = Field(None, description="Human-readable title of the prompt")
    description: str | None = Field(None, description="Description of what the prompt does")
    arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
    fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
    icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this prompt")
    context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context", exclude=True)

    @classmethod
    def from_function(
        cls,
        fn: Callable[..., PromptResult | Awaitable[PromptResult]],
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        icons: list[Icon] | None = None,
        context_kwarg: str | None = None,
    ) -> Prompt:
        """Create a Prompt from a function.

        The function can return:
        - A string (converted to a message)
        - A Message object
        - A dict (converted to a message)
        - A sequence of any of the above
        """
        func_name = name or fn.__name__
        if func_name == "<lambda>":  # pragma: no cover
            raise ValueError("You must provide a name for lambda functions")
        # Find context parameter if it exists
        if context_kwarg is None:  # pragma: no branch
            context_kwarg = find_context_parameter(fn)
        # Get schema from func_metadata, excluding context parameter
        func_arg_metadata = func_metadata(
            fn,
            skip_names=[context_kwarg] if context_kwarg is not None else [],
        )
        parameters = func_arg_metadata.arg_model.model_json_schema()
        # Convert JSON-schema properties to PromptArguments; a parameter is
        # required when listed in the schema's "required" array.
        required_names = set(parameters.get("required", []))
        arguments = [
            PromptArgument(
                name=param_name,
                description=param.get("description"),
                required=param_name in required_names,
            )
            for param_name, param in parameters.get("properties", {}).items()
        ]
        # ensure the arguments are properly cast
        fn = validate_call(fn)
        return cls(
            name=func_name,
            title=title,
            description=description or fn.__doc__ or "",
            arguments=arguments,
            fn=fn,
            icons=icons,
            context_kwarg=context_kwarg,
        )

    async def render(
        self,
        arguments: dict[str, Any] | None = None,
        context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
    ) -> list[Message]:
        """Render the prompt with arguments.

        Raises:
            ValueError: if required arguments are missing, or if the function
                fails or returns something that cannot become messages.
        """
        # Validate required arguments
        if self.arguments:
            required = {arg.name for arg in self.arguments if arg.required}
            provided = set(arguments or {})
            missing = required - provided
            if missing:
                raise ValueError(f"Missing required arguments: {missing}")
        try:
            # Add context to arguments if needed
            call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg)
            # Call function and check if result is a coroutine
            result = self.fn(**call_args)
            if inspect.iscoroutine(result):
                result = await result
            # Normalize a single return value to a sequence
            if not isinstance(result, list | tuple):
                result = [result]
            # Convert each item to a Message
            messages: list[Message] = []
            for msg in result:  # type: ignore[reportUnknownVariableType]
                try:
                    if isinstance(msg, Message):
                        messages.append(msg)
                    elif isinstance(msg, dict):
                        messages.append(message_validator.validate_python(msg))
                    elif isinstance(msg, str):
                        content = TextContent(type="text", text=msg)
                        messages.append(UserMessage(content=content))
                    else:  # pragma: no cover
                        # Last resort: serialize arbitrary objects to JSON text.
                        content = pydantic_core.to_json(msg, fallback=str, indent=2).decode()
                        messages.append(Message(role="user", content=content))
                except Exception as exc:  # pragma: no cover
                    # Chain the original failure for debuggability (B904).
                    raise ValueError(f"Could not convert prompt result to message: {msg}") from exc
            return messages
        except Exception as e:  # pragma: no cover
            raise ValueError(f"Error rendering prompt {self.name}: {e}") from e

View File

@@ -0,0 +1,60 @@
"""Prompt management functionality."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from mcp.server.fastmcp.prompts.base import Message, Prompt
from mcp.server.fastmcp.utilities.logging import get_logger
if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT, RequestT
logger = get_logger(__name__)
class PromptManager:
    """Manages FastMCP prompts."""

    def __init__(self, warn_on_duplicate_prompts: bool = True):
        self._prompts: dict[str, Prompt] = {}
        self.warn_on_duplicate_prompts = warn_on_duplicate_prompts

    def get_prompt(self, name: str) -> Prompt | None:
        """Get prompt by name."""
        return self._prompts.get(name)

    def list_prompts(self) -> list[Prompt]:
        """List all registered prompts."""
        return [*self._prompts.values()]

    def add_prompt(
        self,
        prompt: Prompt,
    ) -> Prompt:
        """Add a prompt to the manager.

        If a prompt with the same name is already registered, the existing
        one is kept (optionally with a warning) - first registration wins.
        """
        duplicate = self._prompts.get(prompt.name)
        if duplicate is not None:
            if self.warn_on_duplicate_prompts:
                logger.warning(f"Prompt already exists: {prompt.name}")
            return duplicate
        self._prompts[prompt.name] = prompt
        return prompt

    async def render_prompt(
        self,
        name: str,
        arguments: dict[str, Any] | None = None,
        context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
    ) -> list[Message]:
        """Render a prompt by name with arguments."""
        prompt = self.get_prompt(name)
        if prompt is None:
            raise ValueError(f"Unknown prompt: {name}")
        return await prompt.render(arguments, context=context)

View File

@@ -0,0 +1,23 @@
from .base import Resource
from .resource_manager import ResourceManager
from .templates import ResourceTemplate
from .types import (
BinaryResource,
DirectoryResource,
FileResource,
FunctionResource,
HttpResource,
TextResource,
)
__all__ = [
"Resource",
"TextResource",
"BinaryResource",
"FunctionResource",
"FileResource",
"HttpResource",
"DirectoryResource",
"ResourceTemplate",
"ResourceManager",
]

View File

@@ -0,0 +1,49 @@
"""Base classes and interfaces for FastMCP resources."""
import abc
from typing import Annotated
from pydantic import (
AnyUrl,
BaseModel,
ConfigDict,
Field,
UrlConstraints,
ValidationInfo,
field_validator,
)
from mcp.types import Annotations, Icon
class Resource(BaseModel, abc.ABC):
    """Base class for all resources."""

    # validate_default=True makes declared defaults (e.g. mime_type) run
    # through field validation as well.
    model_config = ConfigDict(validate_default=True)

    # host_required=False permits host-less URIs such as "resource://name".
    uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(default=..., description="URI of the resource")
    name: str | None = Field(description="Name of the resource", default=None)
    title: str | None = Field(description="Human-readable title of the resource", default=None)
    description: str | None = Field(description="Description of the resource", default=None)
    mime_type: str = Field(
        default="text/plain",
        description="MIME type of the resource content",
        pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+(;\s*[a-zA-Z0-9\-_.]+=[a-zA-Z0-9\-_.]+)*$",
    )
    icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this resource")
    annotations: Annotations | None = Field(default=None, description="Optional annotations for the resource")

    @field_validator("name", mode="before")
    @classmethod
    def set_default_name(cls, name: str | None, info: ValidationInfo) -> str:
        """Set default name from URI if not provided."""
        # mode="before" runs ahead of coercion; uri appears in info.data
        # because it is declared before name on the model.
        if name:
            return name
        if uri := info.data.get("uri"):
            return str(uri)
        raise ValueError("Either name or uri must be provided")

    @abc.abstractmethod
    async def read(self) -> str | bytes:
        """Read the resource content."""
        pass  # pragma: no cover

View File

@@ -0,0 +1,113 @@
"""Resource manager functionality."""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from pydantic import AnyUrl
from mcp.server.fastmcp.resources.base import Resource
from mcp.server.fastmcp.resources.templates import ResourceTemplate
from mcp.server.fastmcp.utilities.logging import get_logger
from mcp.types import Annotations, Icon
if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT, RequestT
logger = get_logger(__name__)
class ResourceManager:
    """Manages FastMCP resources."""

    def __init__(self, warn_on_duplicate_resources: bool = True):
        self._resources: dict[str, Resource] = {}
        self._templates: dict[str, ResourceTemplate] = {}
        self.warn_on_duplicate_resources = warn_on_duplicate_resources

    def add_resource(self, resource: Resource) -> Resource:
        """Add a resource to the manager.

        Args:
            resource: A Resource instance to add

        Returns:
            The added resource. If a resource with the same URI already exists,
            returns the existing resource.
        """
        logger.debug(
            "Adding resource",
            extra={
                "uri": resource.uri,
                "type": type(resource).__name__,
                "resource_name": resource.name,
            },
        )
        existing = self._resources.get(str(resource.uri))
        if existing:
            # First registration wins; duplicates are ignored.
            if self.warn_on_duplicate_resources:
                logger.warning(f"Resource already exists: {resource.uri}")
            return existing
        self._resources[str(resource.uri)] = resource
        return resource

    def add_template(
        self,
        fn: Callable[..., Any],
        uri_template: str,
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        mime_type: str | None = None,
        icons: list[Icon] | None = None,
        annotations: Annotations | None = None,
    ) -> ResourceTemplate:
        """Add a template from a function."""
        template = ResourceTemplate.from_function(
            fn,
            uri_template=uri_template,
            name=name,
            title=title,
            description=description,
            mime_type=mime_type,
            icons=icons,
            annotations=annotations,
        )
        self._templates[template.uri_template] = template
        return template

    async def get_resource(
        self,
        uri: AnyUrl | str,
        context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
    ) -> Resource:
        """Get resource by URI, checking concrete resources first, then templates.

        Raises:
            ValueError: if no concrete resource or template matches the URI,
                or if instantiating a matching template fails.
        """
        uri_str = str(uri)
        logger.debug("Getting resource", extra={"uri": uri_str})
        # First check concrete resources
        if resource := self._resources.get(uri_str):
            return resource
        # Then check templates
        for template in self._templates.values():
            if params := template.matches(uri_str):
                try:
                    return await template.create_resource(uri_str, params, context=context)
                except Exception as e:  # pragma: no cover
                    # Chain the underlying failure for debuggability (B904).
                    raise ValueError(f"Error creating resource from template: {e}") from e
        raise ValueError(f"Unknown resource: {uri}")

    def list_resources(self) -> list[Resource]:
        """List all registered resources."""
        logger.debug("Listing resources", extra={"count": len(self._resources)})
        return list(self._resources.values())

    def list_templates(self) -> list[ResourceTemplate]:
        """List all registered templates."""
        logger.debug("Listing templates", extra={"count": len(self._templates)})
        return list(self._templates.values())

View File

@@ -0,0 +1,118 @@
"""Resource template functionality."""
from __future__ import annotations
import inspect
import re
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field, validate_call
from mcp.server.fastmcp.resources.types import FunctionResource, Resource
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter, inject_context
from mcp.server.fastmcp.utilities.func_metadata import func_metadata
from mcp.types import Annotations, Icon
if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT, RequestT
class ResourceTemplate(BaseModel):
    """A template for dynamically creating resources."""

    uri_template: str = Field(description="URI template with parameters (e.g. weather://{city}/current)")
    name: str = Field(description="Name of the resource")
    title: str | None = Field(description="Human-readable title of the resource", default=None)
    description: str | None = Field(description="Description of what the resource does")
    mime_type: str = Field(default="text/plain", description="MIME type of the resource content")
    icons: list[Icon] | None = Field(default=None, description="Optional list of icons for the resource template")
    annotations: Annotations | None = Field(default=None, description="Optional annotations for the resource template")
    fn: Callable[..., Any] = Field(exclude=True)
    parameters: dict[str, Any] = Field(description="JSON schema for function parameters")
    context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")

    @classmethod
    def from_function(
        cls,
        fn: Callable[..., Any],
        uri_template: str,
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        mime_type: str | None = None,
        icons: list[Icon] | None = None,
        annotations: Annotations | None = None,
        context_kwarg: str | None = None,
    ) -> ResourceTemplate:
        """Create a template from a function."""
        func_name = name or fn.__name__
        if func_name == "<lambda>":
            raise ValueError("You must provide a name for lambda functions")  # pragma: no cover
        # Find context parameter if it exists
        if context_kwarg is None:  # pragma: no branch
            context_kwarg = find_context_parameter(fn)
        # Get schema from func_metadata, excluding context parameter
        func_arg_metadata = func_metadata(
            fn,
            skip_names=[context_kwarg] if context_kwarg is not None else [],
        )
        parameters = func_arg_metadata.arg_model.model_json_schema()
        # ensure the arguments are properly cast
        fn = validate_call(fn)
        return cls(
            uri_template=uri_template,
            name=func_name,
            title=title,
            description=description or fn.__doc__ or "",
            mime_type=mime_type or "text/plain",
            icons=icons,
            annotations=annotations,
            fn=fn,
            parameters=parameters,
            context_kwarg=context_kwarg,
        )

    def matches(self, uri: str) -> dict[str, Any] | None:
        """Check if URI matches template and extract parameters."""
        # Split the template into alternating literal / parameter segments
        # (re.split with a capturing group yields odd indices for params),
        # escaping the literals so regex metacharacters (".", "+", ...) in
        # the template match themselves instead of acting as wildcards.
        segments = re.split(r"\{(\w+)\}", self.uri_template)
        pattern = "".join(
            f"(?P<{segment}>[^/]+)" if index % 2 else re.escape(segment)
            for index, segment in enumerate(segments)
        )
        match = re.match(f"^{pattern}$", uri)
        if match:
            return match.groupdict()
        return None

    async def create_resource(
        self,
        uri: str,
        params: dict[str, Any],
        context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
    ) -> Resource:
        """Create a resource from the template with the given parameters.

        Raises:
            ValueError: wrapping any failure of the wrapped function (chained).
        """
        try:
            # Add context to params if needed
            params = inject_context(self.fn, params, context, self.context_kwarg)
            # Call function and check if result is a coroutine
            result = self.fn(**params)
            if inspect.iscoroutine(result):
                result = await result
            return FunctionResource(
                uri=uri,  # type: ignore
                name=self.name,
                title=self.title,
                description=self.description,
                mime_type=self.mime_type,
                icons=self.icons,
                annotations=self.annotations,
                fn=lambda: result,  # Capture result in closure
            )
        except Exception as e:
            raise ValueError(f"Error creating resource from template: {e}") from e

View File

@@ -0,0 +1,201 @@
"""Concrete resource implementations."""
import inspect
import json
from collections.abc import Callable
from pathlib import Path
from typing import Any
import anyio
import anyio.to_thread
import httpx
import pydantic
import pydantic_core
from pydantic import AnyUrl, Field, ValidationInfo, validate_call
from mcp.server.fastmcp.resources.base import Resource
from mcp.types import Annotations, Icon
class TextResource(Resource):
    """A resource that reads from a string."""

    # The entire payload is held in memory on the model itself.
    text: str = Field(description="Text content of the resource")

    async def read(self) -> str:
        """Read the text content."""
        return self.text  # pragma: no cover
class BinaryResource(Resource):
    """A resource that reads from bytes."""

    # The entire payload is held in memory on the model itself.
    data: bytes = Field(description="Binary content of the resource")

    async def read(self) -> bytes:
        """Read the binary content."""
        return self.data  # pragma: no cover
class FunctionResource(Resource):
    """A resource that defers data loading by wrapping a function.

    The function is only called when the resource is read, allowing for lazy loading
    of potentially expensive data. This is particularly useful when listing resources,
    as the function won't be called until the resource is actually accessed.

    The function can return:
    - str for text content (default)
    - bytes for binary content
    - other types will be converted to JSON
    """

    fn: Callable[[], Any] = Field(exclude=True)

    async def read(self) -> str | bytes:
        """Read the resource by calling the wrapped function.

        Raises:
            ValueError: wrapping any error raised by the function (chained).
        """
        try:
            # Call the function first to see if it returns a coroutine
            result = self.fn()
            # If it's a coroutine, await it
            if inspect.iscoroutine(result):
                result = await result
            if isinstance(result, Resource):  # pragma: no cover
                # The wrapped function may itself produce a Resource; delegate.
                return await result.read()
            elif isinstance(result, bytes):
                return result
            elif isinstance(result, str):
                return result
            else:
                # Anything else is serialized to pretty-printed JSON text.
                return pydantic_core.to_json(result, fallback=str, indent=2).decode()
        except Exception as e:
            # Chain the original exception for debuggability (PEP 3134).
            raise ValueError(f"Error reading resource {self.uri}: {e}") from e

    @classmethod
    def from_function(
        cls,
        fn: Callable[..., Any],
        uri: str,
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        mime_type: str | None = None,
        icons: list[Icon] | None = None,
        annotations: Annotations | None = None,
    ) -> "FunctionResource":
        """Create a FunctionResource from a function."""
        func_name = name or fn.__name__
        if func_name == "<lambda>":  # pragma: no cover
            raise ValueError("You must provide a name for lambda functions")
        # ensure the arguments are properly cast
        fn = validate_call(fn)
        return cls(
            uri=AnyUrl(uri),
            name=func_name,
            title=title,
            description=description or fn.__doc__ or "",
            mime_type=mime_type or "text/plain",
            fn=fn,
            icons=icons,
            annotations=annotations,
        )
class FileResource(Resource):
    """A resource that reads from a file.

    Set is_binary=True to read file as binary data instead of text.
    """

    path: Path = Field(description="Path to the file")
    is_binary: bool = Field(
        default=False,
        description="Whether to read the file as binary data",
    )
    mime_type: str = Field(
        default="text/plain",
        description="MIME type of the resource content",
    )

    @pydantic.field_validator("path")
    @classmethod
    def validate_absolute_path(cls, path: Path) -> Path:  # pragma: no cover
        """Ensure path is absolute."""
        if not path.is_absolute():
            raise ValueError("Path must be absolute")
        return path

    @pydantic.field_validator("is_binary")
    @classmethod
    def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> bool:
        """Set is_binary based on mime_type if not explicitly set."""
        # An explicit is_binary=True always wins; otherwise any non-text/*
        # MIME type implies binary content.
        if is_binary:
            return True
        mime_type = info.data.get("mime_type", "text/plain")
        return not mime_type.startswith("text/")

    async def read(self) -> str | bytes:
        """Read the file content.

        Raises:
            ValueError: wrapping any I/O error (chained).
        """
        try:
            # File I/O runs in a worker thread to keep the event loop free.
            if self.is_binary:
                return await anyio.to_thread.run_sync(self.path.read_bytes)
            # NOTE(review): read_text uses the platform default encoding -
            # confirm whether UTF-8 should be forced here.
            return await anyio.to_thread.run_sync(self.path.read_text)
        except Exception as e:
            # Chain the original exception for debuggability (PEP 3134).
            raise ValueError(f"Error reading file {self.path}: {e}") from e
class HttpResource(Resource):
    """A resource that reads from an HTTP endpoint."""

    url: str = Field(description="URL to fetch content from")
    mime_type: str = Field(default="application/json", description="MIME type of the resource content")

    async def read(self) -> str | bytes:
        """Read the HTTP content.

        Raises:
            httpx.HTTPStatusError: if the server returns a 4xx/5xx response.
        """
        # A fresh client per read keeps the resource connectionless. Note the
        # body is always returned as decoded text, regardless of mime_type.
        async with httpx.AsyncClient() as client:  # pragma: no cover
            response = await client.get(self.url)
            response.raise_for_status()
            return response.text
class DirectoryResource(Resource):
    """A resource that lists files in a directory."""

    path: Path = Field(description="Path to the directory")
    recursive: bool = Field(default=False, description="Whether to list files recursively")
    pattern: str | None = Field(default=None, description="Optional glob pattern to filter files")
    mime_type: str = Field(default="application/json", description="MIME type of the resource content")

    @pydantic.field_validator("path")
    @classmethod
    def validate_absolute_path(cls, path: Path) -> Path:  # pragma: no cover
        """Ensure path is absolute."""
        if not path.is_absolute():
            raise ValueError("Path must be absolute")
        return path

    def list_files(self) -> list[Path]:  # pragma: no cover
        """List files in the directory.

        Raises:
            FileNotFoundError: if the directory does not exist.
            NotADirectoryError: if the path is not a directory.
            ValueError: wrapping any globbing failure (chained).
        """
        if not self.path.exists():
            raise FileNotFoundError(f"Directory not found: {self.path}")
        if not self.path.is_dir():
            raise NotADirectoryError(f"Not a directory: {self.path}")
        try:
            # rglob walks subdirectories; glob stays at the top level.
            # An unset (or empty) pattern falls back to matching everything.
            globber = self.path.rglob if self.recursive else self.path.glob
            return list(globber(self.pattern or "*"))
        except Exception as e:
            # Chain the original exception for debuggability (PEP 3134).
            raise ValueError(f"Error listing directory {self.path}: {e}") from e

    async def read(self) -> str:  # Always returns JSON string  # pragma: no cover
        """Read the directory listing."""
        try:
            # Globbing touches the filesystem; run it off the event loop.
            files = await anyio.to_thread.run_sync(self.list_files)
            file_list = [str(f.relative_to(self.path)) for f in files if f.is_file()]
            return json.dumps({"files": file_list}, indent=2)
        except Exception as e:
            raise ValueError(f"Error reading directory {self.path}: {e}") from e

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,4 @@
from .base import Tool
from .tool_manager import ToolManager
__all__ = ["Tool", "ToolManager"]

View File

@@ -0,0 +1,126 @@
from __future__ import annotations as _annotations
import functools
import inspect
from collections.abc import Callable
from functools import cached_property
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.utilities.context_injection import find_context_parameter
from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata
from mcp.shared.exceptions import UrlElicitationRequiredError
from mcp.shared.tool_name_validation import validate_and_warn_tool_name
from mcp.types import Icon, ToolAnnotations
if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
from mcp.shared.context import LifespanContextT, RequestT
class Tool(BaseModel):
    """Internal tool registration info."""

    fn: Callable[..., Any] = Field(exclude=True)
    name: str = Field(description="Name of the tool")
    title: str | None = Field(None, description="Human-readable title of the tool")
    description: str = Field(description="Description of what the tool does")
    parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
    fn_metadata: FuncMetadata = Field(
        description="Metadata about the function including a pydantic model for tool arguments"
    )
    is_async: bool = Field(description="Whether the tool is async")
    context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
    annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool")
    icons: list[Icon] | None = Field(default=None, description="Optional list of icons for this tool")
    meta: dict[str, Any] | None = Field(default=None, description="Optional metadata for this tool")

    @cached_property
    def output_schema(self) -> dict[str, Any] | None:
        # Cached: derived from fn_metadata, which never changes after creation.
        return self.fn_metadata.output_schema

    @classmethod
    def from_function(
        cls,
        fn: Callable[..., Any],
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        context_kwarg: str | None = None,
        annotations: ToolAnnotations | None = None,
        icons: list[Icon] | None = None,
        meta: dict[str, Any] | None = None,
        structured_output: bool | None = None,
    ) -> Tool:
        """Create a Tool from a function.

        Name defaults to the function's __name__ and description to its
        docstring; a context parameter, if detected, is excluded from the
        advertised parameter schema.
        """
        func_name = name or fn.__name__
        # Warns (does not fail) on names outside tool-naming guidance.
        validate_and_warn_tool_name(func_name)
        if func_name == "<lambda>":
            raise ValueError("You must provide a name for lambda functions")
        func_doc = description or fn.__doc__ or ""
        is_async = _is_async_callable(fn)
        # Auto-detect the context parameter unless explicitly supplied.
        if context_kwarg is None:  # pragma: no branch
            context_kwarg = find_context_parameter(fn)
        # Build argument-validation metadata, hiding the context kwarg from
        # the client-facing JSON schema.
        func_arg_metadata = func_metadata(
            fn,
            skip_names=[context_kwarg] if context_kwarg is not None else [],
            structured_output=structured_output,
        )
        parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True)
        return cls(
            fn=fn,
            name=func_name,
            title=title,
            description=func_doc,
            parameters=parameters,
            fn_metadata=func_arg_metadata,
            is_async=is_async,
            context_kwarg=context_kwarg,
            annotations=annotations,
            icons=icons,
            meta=meta,
        )

    async def run(
        self,
        arguments: dict[str, Any],
        context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
        convert_result: bool = False,
    ) -> Any:
        """Run the tool with arguments.

        Raises:
            ToolError: if argument validation or execution fails.
        """
        try:
            result = await self.fn_metadata.call_fn_with_arg_validation(
                self.fn,
                self.is_async,
                arguments,
                {self.context_kwarg: context} if self.context_kwarg is not None else None,
            )
            if convert_result:
                result = self.fn_metadata.convert_result(result)
            return result
        except UrlElicitationRequiredError:
            # Re-raise UrlElicitationRequiredError so it can be properly handled
            # as an MCP error response with code -32042
            raise
        except Exception as e:
            raise ToolError(f"Error executing tool {self.name}: {e}") from e
def _is_async_callable(obj: Any) -> bool:
while isinstance(obj, functools.partial): # pragma: no cover
obj = obj.func
return inspect.iscoroutinefunction(obj) or (
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
)

View File

@@ -0,0 +1,93 @@
from __future__ import annotations as _annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from mcp.server.fastmcp.exceptions import ToolError
from mcp.server.fastmcp.tools.base import Tool
from mcp.server.fastmcp.utilities.logging import get_logger
from mcp.shared.context import LifespanContextT, RequestT
from mcp.types import Icon, ToolAnnotations
if TYPE_CHECKING:
from mcp.server.fastmcp.server import Context
from mcp.server.session import ServerSessionT
logger = get_logger(__name__)
class ToolManager:
    """Manages FastMCP tools."""

    def __init__(
        self,
        warn_on_duplicate_tools: bool = True,
        *,
        tools: list[Tool] | None = None,
    ):
        self._tools: dict[str, Tool] = {}
        for tool in tools or []:
            # NOTE(review): when seeding via the constructor, a later
            # duplicate replaces an earlier one (last wins) - the opposite
            # of add_tool(). Preserved as-is; confirm this is intended.
            if warn_on_duplicate_tools and tool.name in self._tools:
                logger.warning(f"Tool already exists: {tool.name}")
            self._tools[tool.name] = tool
        self.warn_on_duplicate_tools = warn_on_duplicate_tools

    def get_tool(self, name: str) -> Tool | None:
        """Get tool by name."""
        return self._tools.get(name)

    def list_tools(self) -> list[Tool]:
        """List all registered tools."""
        return [*self._tools.values()]

    def add_tool(
        self,
        fn: Callable[..., Any],
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        annotations: ToolAnnotations | None = None,
        icons: list[Icon] | None = None,
        meta: dict[str, Any] | None = None,
        structured_output: bool | None = None,
    ) -> Tool:
        """Add a tool to the server.

        If a tool with the same name already exists, the existing tool is
        kept (optionally with a warning) and returned - first wins.
        """
        tool = Tool.from_function(
            fn,
            name=name,
            title=title,
            description=description,
            annotations=annotations,
            icons=icons,
            meta=meta,
            structured_output=structured_output,
        )
        duplicate = self._tools.get(tool.name)
        if duplicate is not None:
            if self.warn_on_duplicate_tools:
                logger.warning(f"Tool already exists: {tool.name}")
            return duplicate
        self._tools[tool.name] = tool
        return tool

    def remove_tool(self, name: str) -> None:
        """Remove a tool by name."""
        try:
            del self._tools[name]
        except KeyError:
            raise ToolError(f"Unknown tool: {name}") from None

    async def call_tool(
        self,
        name: str,
        arguments: dict[str, Any],
        context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
        convert_result: bool = False,
    ) -> Any:
        """Call a tool by name with arguments."""
        tool = self.get_tool(name)
        if tool is None:
            raise ToolError(f"Unknown tool: {name}")
        return await tool.run(arguments, context=context, convert_result=convert_result)

View File

@@ -0,0 +1 @@
"""FastMCP utility modules."""

View File

@@ -0,0 +1,68 @@
"""Context injection utilities for FastMCP."""
from __future__ import annotations
import inspect
import typing
from collections.abc import Callable
from typing import Any
def find_context_parameter(fn: Callable[..., Any]) -> str | None:
    """Find the parameter that should receive the Context object.

    Searches through the function's signature to find a parameter
    with a Context type annotation.

    Args:
        fn: The function to inspect

    Returns:
        The name of the context parameter, or None if not found
    """
    from mcp.server.fastmcp.server import Context

    # Resolve string / forward-reference annotations first; if resolution
    # fails we cannot identify a context parameter at all.
    try:
        hints = typing.get_type_hints(fn)
    except Exception:
        return None

    for name, hint in hints.items():
        # Consider the hint itself, plus its type arguments for generics
        # such as Optional[Context] / Context | None.
        candidates: list[Any] = [hint]
        if typing.get_origin(hint) is not None:
            candidates.extend(typing.get_args(hint))
        for candidate in candidates:
            if inspect.isclass(candidate) and issubclass(candidate, Context):
                return name
    return None
def inject_context(
    fn: Callable[..., Any],
    kwargs: dict[str, Any],
    context: Any | None,
    context_kwarg: str | None,
) -> dict[str, Any]:
    """Inject context into function kwargs if needed.

    Args:
        fn: The function that will be called
        kwargs: The current keyword arguments
        context: The context object to inject (if any)
        context_kwarg: The name of the parameter to inject into

    Returns:
        Updated kwargs with context injected if applicable; the original
        mapping is returned untouched when no injection is required.
    """
    if context_kwarg is None or context is None:
        return kwargs
    # Copy rather than mutate so the caller's dict stays unchanged.
    merged = dict(kwargs)
    merged[context_kwarg] = context
    return merged

View File

@@ -0,0 +1,533 @@
import inspect
import json
from collections.abc import Awaitable, Callable, Sequence
from itertools import chain
from types import GenericAlias
from typing import Annotated, Any, cast, get_args, get_origin, get_type_hints
import pydantic_core
from pydantic import (
BaseModel,
ConfigDict,
Field,
RootModel,
WithJsonSchema,
create_model,
)
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaWarningKind
from typing_extensions import is_typeddict
from typing_inspection.introspection import (
UNKNOWN,
AnnotationSource,
ForbiddenQualifier,
inspect_annotation,
is_union_origin,
)
from mcp.server.fastmcp.exceptions import InvalidSignature
from mcp.server.fastmcp.utilities.logging import get_logger
from mcp.server.fastmcp.utilities.types import Audio, Image
from mcp.types import CallToolResult, ContentBlock, TextContent
logger = get_logger(__name__)
class StrictJsonSchema(GenerateJsonSchema):
    """A JSON schema generator that raises exceptions instead of emitting warnings.

    This is used to detect non-serializable types during schema generation.
    """

    def emit_warning(self, kind: JsonSchemaWarningKind, detail: str) -> None:
        # Escalate every schema warning into a hard failure so callers can
        # treat the type as non-serializable.
        message = f"JSON schema warning: {kind} - {detail}"
        raise ValueError(message)
class ArgModelBase(BaseModel):
    """A model representing the arguments to a function."""

    def model_dump_one_level(self) -> dict[str, Any]:
        """Return a dict of the model's fields, one level deep.

        That is, sub-models etc are not dumped - they are kept as pydantic models.
        """
        # Keys use the field alias when one is defined, else the field name.
        return {
            (info.alias or name): getattr(self, name)
            for name, info in self.__class__.model_fields.items()
        }

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )
class FuncMetadata(BaseModel):
    """Metadata about a tool function's signature and optional structured output.

    Produced by ``func_metadata``; used to validate incoming arguments before a
    call and to convert the function's result for the lowlevel tool-call handler.
    """

    # Pydantic model mirroring the function's parameters. WithJsonSchema(None)
    # keeps the model type itself out of this model's JSON schema.
    arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)]
    # JSON schema for structured output; None means the tool is unstructured.
    output_schema: dict[str, Any] | None = None
    # Model used to validate structured results; set iff output_schema is set.
    output_model: Annotated[type[BaseModel], WithJsonSchema(None)] | None = None
    # True when the raw return value must be wrapped as {"result": value}
    # before validation against output_model.
    wrap_output: bool = False

    async def call_fn_with_arg_validation(
        self,
        fn: Callable[..., Any | Awaitable[Any]],
        fn_is_async: bool,
        arguments_to_validate: dict[str, Any],
        arguments_to_pass_directly: dict[str, Any] | None,
    ) -> Any:
        """Call the given function with arguments validated and injected.

        Arguments are first attempted to be parsed from JSON, then validated against
        the argument model, before being passed to the function.
        """
        arguments_pre_parsed = self.pre_parse_json(arguments_to_validate)
        arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed)
        arguments_parsed_dict = arguments_parsed_model.model_dump_one_level()

        # Directly-passed arguments (e.g. the injected Context) bypass validation.
        arguments_parsed_dict |= arguments_to_pass_directly or {}

        if fn_is_async:
            return await fn(**arguments_parsed_dict)
        else:
            return fn(**arguments_parsed_dict)

    def convert_result(self, result: Any) -> Any:
        """
        Convert the result of a function call to the appropriate format for
        the lowlevel server tool call handler:

        - If output_model is None, return the unstructured content directly.
        - If output_model is not None, convert the result to structured output format
          (dict[str, Any]) and return both unstructured and structured content.

        Note: we return unstructured content here **even though the lowlevel server
        tool call handler provides generic backwards compatibility serialization of
        structured content**. This is for FastMCP backwards compatibility: we need to
        retain FastMCP's ad hoc conversion logic for constructing unstructured output
        from function return values, whereas the lowlevel server simply serializes
        the structured output.
        """
        if isinstance(result, CallToolResult):
            # Pre-built results pass through untouched; when an output schema
            # exists, the structured content is still validated against it.
            if self.output_schema is not None:
                assert self.output_model is not None, "Output model must be set if output schema is defined"
                self.output_model.model_validate(result.structuredContent)
            return result

        unstructured_content = _convert_to_content(result)

        if self.output_schema is None:
            return unstructured_content
        else:
            if self.wrap_output:
                # Primitive/generic returns are wrapped so they fit the
                # {"result": ...} output model created by func_metadata.
                result = {"result": result}

            assert self.output_model is not None, "Output model must be set if output schema is defined"
            validated = self.output_model.model_validate(result)
            structured_content = validated.model_dump(mode="json", by_alias=True)
            return (unstructured_content, structured_content)

    def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
        """Pre-parse data from JSON.

        Return a dict with same keys as input but with values parsed from JSON
        if appropriate.

        This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside
        a string rather than an actual list. Claude desktop is prone to this - in fact
        it seems incapable of NOT doing this. For sub-models, it tends to pass
        dicts (JSON objects) as JSON strings, which can be pre-parsed here.
        """
        new_data = data.copy()  # Shallow copy

        # Build a mapping from input keys (including aliases) to field info
        key_to_field_info: dict[str, FieldInfo] = {}
        for field_name, field_info in self.arg_model.model_fields.items():
            # Map both the field name and its alias (if any) to the field info
            key_to_field_info[field_name] = field_info
            if field_info.alias:
                key_to_field_info[field_info.alias] = field_info

        for data_key, data_value in data.items():
            if data_key not in key_to_field_info:  # pragma: no cover
                continue
            field_info = key_to_field_info[data_key]
            # Only attempt JSON parsing for string values on non-str fields;
            # a str-typed field should receive the raw string unchanged.
            if isinstance(data_value, str) and field_info.annotation is not str:
                try:
                    pre_parsed = json.loads(data_value)
                except json.JSONDecodeError:
                    continue  # Not JSON - skip
                if isinstance(pre_parsed, str | int | float):
                    # This is likely that the raw value is e.g. `"hello"` which we
                    # Should really be parsed as '"hello"' in Python - but if we parse
                    # it as JSON it'll turn into just 'hello'. So we skip it.
                    continue
                new_data[data_key] = pre_parsed
        # Pre-parsing must never add or drop keys, only reinterpret values.
        assert new_data.keys() == data.keys()
        return new_data

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )
def func_metadata(
    func: Callable[..., Any],
    skip_names: Sequence[str] = (),
    structured_output: bool | None = None,
) -> FuncMetadata:
    """Given a function, return metadata including a pydantic model representing its
    signature.

    The use case for this is
    ```
    meta = func_metadata(func)
    validated_args = meta.arg_model.model_validate(some_raw_data_dict)
    return func(**validated_args.model_dump_one_level())
    ```

    **critically** it also provides pre-parse helper to attempt to parse things from
    JSON.

    Args:
        func: The function to convert to a pydantic model
        skip_names: A list of parameter names to skip. These will not be included in
            the model.
        structured_output: Controls whether the tool's output is structured or unstructured

            - If None, auto-detects based on the function's return type annotation
            - If True, creates a structured tool (return type annotation permitting)
            - If False, unconditionally creates an unstructured tool

        If structured, creates a Pydantic model for the function's result based on its annotation.
        Supports various return types:
            - BaseModel subclasses (used directly)
            - Primitive types (str, int, float, bool, bytes, None) - wrapped in a
                model with a 'result' field
            - TypedDict - converted to a Pydantic model with same fields
            - Dataclasses and other annotated classes - converted to Pydantic models
            - Generic types (list, dict, Union, etc.) - wrapped in a model with a 'result' field

    Raises:
        InvalidSignature: when annotations cannot be evaluated, a parameter name
            starts with '_', structured output is requested but unachievable, or
            CallToolResult is used inside a Union/Optional.

    Returns:
        A FuncMetadata object containing:
        - arg_model: A pydantic model representing the function's arguments
        - output_model: A pydantic model for the return type if output is structured
        - output_conversion: Records how function output should be converted before returning.
    """
    try:
        sig = inspect.signature(func, eval_str=True)
    except NameError as e:  # pragma: no cover
        # This raise could perhaps be skipped, and we (FastMCP) just call
        # model_rebuild right before using it 🤷
        raise InvalidSignature(f"Unable to evaluate type annotations for callable {func.__name__!r}") from e
    params = sig.parameters
    dynamic_pydantic_model_params: dict[str, Any] = {}
    for param in params.values():
        if param.name.startswith("_"):  # pragma: no cover
            raise InvalidSignature(f"Parameter {param.name} of {func.__name__} cannot start with '_'")
        if param.name in skip_names:
            continue
        # Unannotated parameters are treated as Any...
        annotation = param.annotation if param.annotation is not inspect.Parameter.empty else Any
        field_name = param.name
        field_kwargs: dict[str, Any] = {}
        field_metadata: list[Any] = []
        if param.annotation is inspect.Parameter.empty:
            # ...but advertised as strings in the generated JSON schema.
            field_metadata.append(WithJsonSchema({"title": param.name, "type": "string"}))
        # Check if the parameter name conflicts with BaseModel attributes
        # This is necessary because Pydantic warns about shadowing parent attributes
        if hasattr(BaseModel, field_name) and callable(getattr(BaseModel, field_name)):
            # Use an alias to avoid the shadowing warning
            field_kwargs["alias"] = field_name
            # Use a prefixed field name
            field_name = f"field_{field_name}"
        if param.default is not inspect.Parameter.empty:
            dynamic_pydantic_model_params[field_name] = (
                Annotated[(annotation, *field_metadata, Field(**field_kwargs))],
                param.default,
            )
        else:
            dynamic_pydantic_model_params[field_name] = Annotated[(annotation, *field_metadata, Field(**field_kwargs))]

    arguments_model = create_model(
        f"{func.__name__}Arguments",
        __base__=ArgModelBase,
        **dynamic_pydantic_model_params,
    )

    if structured_output is False:
        return FuncMetadata(arg_model=arguments_model)

    # set up structured output support based on return type annotation
    if sig.return_annotation is inspect.Parameter.empty and structured_output is True:
        raise InvalidSignature(f"Function {func.__name__}: return annotation required for structured output")

    try:
        inspected_return_ann = inspect_annotation(sig.return_annotation, annotation_source=AnnotationSource.FUNCTION)
    except ForbiddenQualifier as e:
        raise InvalidSignature(f"Function {func.__name__}: return annotation contains an invalid type qualifier") from e
    return_type_expr = inspected_return_ann.type
    # `AnnotationSource.FUNCTION` allows no type qualifier to be used, so `return_type_expr` is guaranteed to *not* be
    # unknown (i.e. a bare `Final`).
    assert return_type_expr is not UNKNOWN

    if is_union_origin(get_origin(return_type_expr)):
        args = get_args(return_type_expr)
        # Check if CallToolResult appears in the union (excluding None for Optional check)
        if any(isinstance(arg, type) and issubclass(arg, CallToolResult) for arg in args if arg is not type(None)):
            raise InvalidSignature(
                f"Function {func.__name__}: CallToolResult cannot be used in Union or Optional types. "
                "To return empty results, use: CallToolResult(content=[])"
            )

    original_annotation: Any
    # if the typehint is CallToolResult, the user either intends to return without validation
    # or they provided validation as Annotated metadata
    if isinstance(return_type_expr, type) and issubclass(return_type_expr, CallToolResult):
        if inspected_return_ann.metadata:
            return_type_expr = inspected_return_ann.metadata[0]
            if len(inspected_return_ann.metadata) >= 2:
                # Reconstruct the original annotation, by preserving the remaining metadata,
                # i.e. from `Annotated[CallToolResult, ReturnType, Gt(1)]` to
                # `Annotated[ReturnType, Gt(1)]`:
                original_annotation = Annotated[
                    (return_type_expr, *inspected_return_ann.metadata[1:])
                ]  # pragma: no cover
            else:
                # We only had `Annotated[CallToolResult, ReturnType]`, treat the original annotation
                # as being `ReturnType`:
                original_annotation = return_type_expr
        else:
            # Bare CallToolResult return: no validation model is built.
            return FuncMetadata(arg_model=arguments_model)
    else:
        original_annotation = sig.return_annotation

    output_model, output_schema, wrap_output = _try_create_model_and_schema(
        original_annotation, return_type_expr, func.__name__
    )

    if output_model is None and structured_output is True:
        # Model creation failed or produced warnings - no structured output
        raise InvalidSignature(
            f"Function {func.__name__}: return type {return_type_expr} is not serializable for structured output"
        )

    return FuncMetadata(
        arg_model=arguments_model,
        output_schema=output_schema,
        output_model=output_model,
        wrap_output=wrap_output,
    )
def _try_create_model_and_schema(
    original_annotation: Any,
    type_expr: Any,
    func_name: str,
) -> tuple[type[BaseModel] | None, dict[str, Any] | None, bool]:
    """Try to create a model and schema for the given annotation without warnings.

    Args:
        original_annotation: The original return annotation (may be wrapped in `Annotated`).
        type_expr: The underlying type expression derived from the return annotation
            (`Annotated` and type qualifiers were stripped).
        func_name: The name of the function.

    Returns:
        tuple of (model or None, schema or None, wrap_output)
        Model and schema are None if warnings occur or creation fails.
        wrap_output is True if the result needs to be wrapped in {"result": ...}
    """
    model = None
    wrap_output = False

    # First handle special case: None
    if type_expr is None:
        model = _create_wrapped_model(func_name, original_annotation)
        wrap_output = True

    # Handle GenericAlias types (list[str], dict[str, int], Union[str, int], etc.)
    elif isinstance(type_expr, GenericAlias):
        origin = get_origin(type_expr)

        # Special case: dict with string keys can use RootModel
        if origin is dict:
            args = get_args(type_expr)
            if len(args) == 2 and args[0] is str:
                # TODO: should we use the original annotation? We are losing any potential `Annotated`
                # metadata for Pydantic here:
                model = _create_dict_model(func_name, type_expr)
            else:
                # dict with non-str keys needs wrapping
                model = _create_wrapped_model(func_name, original_annotation)
                wrap_output = True
        else:
            # All other generic types need wrapping (list, tuple, Union, Optional, etc.)
            model = _create_wrapped_model(func_name, original_annotation)
            wrap_output = True

    # Handle regular type objects
    elif isinstance(type_expr, type):
        type_annotation = cast(type[Any], type_expr)

        # Case 1: BaseModel subclasses (can be used directly)
        if issubclass(type_annotation, BaseModel):
            model = type_annotation

        # Case 2: TypedDicts:
        elif is_typeddict(type_annotation):
            model = _create_model_from_typeddict(type_annotation)

        # Case 3: Primitive types that need wrapping
        elif type_annotation in (str, int, float, bool, bytes, type(None)):
            model = _create_wrapped_model(func_name, original_annotation)
            wrap_output = True

        # Case 4: Other class types (dataclasses, regular classes with annotations)
        else:
            type_hints = get_type_hints(type_annotation)
            if type_hints:
                # Classes with type hints can be converted to Pydantic models
                model = _create_model_from_class(type_annotation, type_hints)
            # Classes without type hints are not serializable - model remains None

    # Handle any other types not covered above
    else:
        # This includes typing constructs that aren't GenericAlias in Python 3.10
        # (e.g., Union, Optional in some Python versions)
        model = _create_wrapped_model(func_name, original_annotation)
        wrap_output = True

    if model:
        # If we successfully created a model, try to get its schema
        # Use StrictJsonSchema to raise exceptions instead of warnings
        try:
            schema = model.model_json_schema(schema_generator=StrictJsonSchema)
        except (TypeError, ValueError, pydantic_core.SchemaError, pydantic_core.ValidationError) as e:
            # These are expected errors when a type can't be converted to a Pydantic schema
            # TypeError: When Pydantic can't handle the type
            # ValueError: When there are issues with the type definition (including our custom warnings)
            # SchemaError: When Pydantic can't build a schema
            # ValidationError: When validation fails
            logger.info(f"Cannot create schema for type {type_expr} in {func_name}: {type(e).__name__}: {e}")
            return None, None, False

        return model, schema, wrap_output

    return None, None, False
# Sentinel distinguishing "no class-level default" from an explicit default of None.
_no_default = object()


def _create_model_from_class(cls: type[Any], type_hints: dict[str, Any]) -> type[BaseModel]:
    """Create a Pydantic model from an ordinary class.

    The created model will:
    - Have the same name as the class
    - Have fields with the same names and types as the class's fields
    - Include all fields whose type does not include None in the set of required fields

    Precondition: cls must have type hints (i.e., `type_hints` is non-empty)
    """
    model_fields: dict[str, Any] = {}
    for attr_name, attr_type in type_hints.items():
        # Private attributes are not part of the model.
        if attr_name.startswith("_"):  # pragma: no cover
            continue
        class_default = getattr(cls, attr_name, _no_default)
        # A bare type makes the field required; a (type, default) pair does not.
        model_fields[attr_name] = attr_type if class_default is _no_default else (attr_type, class_default)
    return create_model(cls.__name__, __config__=ConfigDict(from_attributes=True), **model_fields)
def _create_model_from_typeddict(td_type: type[Any]) -> type[BaseModel]:
    """Create a Pydantic model from a TypedDict.

    The created model will have the same name and fields as the TypedDict.
    Optional TypedDict keys get default=None so they are not required in the
    Pydantic model; dump with exclude_unset=True to recover TypedDict semantics.
    """
    type_hints = get_type_hints(td_type)
    required_keys = getattr(td_type, "__required_keys__", set(type_hints.keys()))
    model_fields: dict[str, Any] = {
        key: (hint if key in required_keys else (hint, None)) for key, hint in type_hints.items()
    }
    return create_model(td_type.__name__, **model_fields)
def _create_wrapped_model(func_name: str, annotation: Any) -> type[BaseModel]:
    """Create a model that wraps a type in a 'result' field.

    This is used for primitive types, generic types like list/dict, etc.
    """
    return create_model(f"{func_name}Output", result=annotation)
def _create_dict_model(func_name: str, dict_annotation: Any) -> type[BaseModel]:
    """Create a RootModel for dict[str, T] types."""

    class DictModel(RootModel[dict_annotation]):
        pass

    # Give the generated class a meaningful, function-specific name.
    model_name = f"{func_name}DictOutput"
    DictModel.__name__ = model_name
    DictModel.__qualname__ = model_name
    return DictModel
def _convert_to_content(result: Any) -> Sequence[ContentBlock]:
    """
    Convert a result to a sequence of content objects.

    Note: This conversion logic comes from previous versions of FastMCP and is being
    retained for purposes of backwards compatibility. It produces different unstructured
    output than the lowlevel server tool call handler, which just serializes structured
    content verbatim.
    """
    if result is None:  # pragma: no cover
        return []
    if isinstance(result, ContentBlock):
        return [result]
    if isinstance(result, Image):
        return [result.to_image_content()]
    if isinstance(result, Audio):
        return [result.to_audio_content()]
    if isinstance(result, (list, tuple)):
        # Flatten by converting each element recursively.
        converted: list[ContentBlock] = []
        for item in result:  # type: ignore
            converted.extend(_convert_to_content(item))
        return converted
    # Fall back to text: non-strings are JSON-serialized first.
    if isinstance(result, str):
        text = result
    else:
        text = pydantic_core.to_json(result, fallback=str, indent=2).decode()
    return [TextContent(type="text", text=text)]

View File

@@ -0,0 +1,43 @@
"""Logging utilities for FastMCP."""
import logging
from typing import Literal
def get_logger(name: str) -> logging.Logger:
    """Return the stdlib logger for *name*.

    NOTE(review): no prefix is applied here — the name is passed straight to
    ``logging.getLogger``. Callers typically pass ``__name__``, so loggers end
    up under the caller's own package namespace.

    Args:
        name: the name of the logger, typically the caller's ``__name__``

    Returns:
        a logger instance
    """
    return logging.getLogger(name)
def configure_logging(
    level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
) -> None:
    """Configure logging for MCP.

    Uses a rich handler when the optional ``rich`` package is installed,
    otherwise falls back to a plain stream handler.

    Args:
        level: the log level to use
    """
    try:  # pragma: no cover
        from rich.console import Console
        from rich.logging import RichHandler

        handler: logging.Handler = RichHandler(console=Console(stderr=True), rich_tracebacks=True)
    except ImportError:  # pragma: no cover
        handler = logging.StreamHandler()
    logging.basicConfig(
        level=level,
        format="%(message)s",
        handlers=[handler],
    )

View File

@@ -0,0 +1,101 @@
"""Common types used across FastMCP."""
import base64
from pathlib import Path
from mcp.types import AudioContent, ImageContent
class Image:
    """Helper class for returning images from tools."""

    def __init__(
        self,
        path: str | Path | None = None,
        data: bytes | None = None,
        format: str | None = None,
    ):
        """Accept exactly one of *path* or *data*, with an optional format hint."""
        if path is None and data is None:  # pragma: no cover
            raise ValueError("Either path or data must be provided")
        if path is not None and data is not None:  # pragma: no cover
            raise ValueError("Only one of path or data can be provided")
        self.path = Path(path) if path else None
        self.data = data
        self._format = format
        self._mime_type = self._get_mime_type()

    def _get_mime_type(self) -> str:
        """Get MIME type from format or guess from file extension."""
        if self._format:  # pragma: no cover
            return f"image/{self._format.lower()}"
        if self.path is None:
            return "image/png"  # pragma: no cover # default for raw binary data
        known_suffixes = {
            ".png": "image/png",
            ".jpg": "image/jpeg",
            ".jpeg": "image/jpeg",
            ".gif": "image/gif",
            ".webp": "image/webp",
        }
        return known_suffixes.get(self.path.suffix.lower(), "application/octet-stream")

    def to_image_content(self) -> ImageContent:
        """Convert to MCP ImageContent."""
        if self.path:
            raw = self.path.read_bytes()
        elif self.data is not None:  # pragma: no cover
            raw = self.data
        else:  # pragma: no cover
            raise ValueError("No image data available")
        encoded = base64.b64encode(raw).decode()
        return ImageContent(type="image", data=encoded, mimeType=self._mime_type)
class Audio:
    """Helper class for returning audio from tools."""

    def __init__(
        self,
        path: str | Path | None = None,
        data: bytes | None = None,
        format: str | None = None,
    ):
        """Accept exactly one of *path* or *data*, with an optional format hint.

        Raises:
            ValueError: if neither or both of path and data are given.
        """
        # Fix: mirror Image's explicit validation. The previous XOR check
        # (`if not bool(path) ^ bool(data)`) treated empty-but-present values
        # (e.g. data=b"") as missing, and its single message
        # ("Either path or data can be provided") did not distinguish the
        # "neither given" case from the "both given" case.
        if path is None and data is None:
            raise ValueError("Either path or data must be provided")
        if path is not None and data is not None:
            raise ValueError("Only one of path or data can be provided")
        self.path = Path(path) if path else None
        self.data = data
        self._format = format
        self._mime_type = self._get_mime_type()

    def _get_mime_type(self) -> str:
        """Get MIME type from format or guess from file extension."""
        if self._format:  # pragma: no cover
            return f"audio/{self._format.lower()}"

        if self.path:
            suffix = self.path.suffix.lower()
            return {
                ".wav": "audio/wav",
                ".mp3": "audio/mpeg",
                ".ogg": "audio/ogg",
                ".flac": "audio/flac",
                ".aac": "audio/aac",
                ".m4a": "audio/mp4",
            }.get(suffix, "application/octet-stream")
        return "audio/wav"  # pragma: no cover # default for raw binary data

    def to_audio_content(self) -> AudioContent:
        """Convert to MCP AudioContent (base64-encoded payload)."""
        if self.path:
            with open(self.path, "rb") as f:
                data = base64.b64encode(f.read()).decode()
        elif self.data is not None:  # pragma: no cover
            data = base64.b64encode(self.data).decode()
        else:  # pragma: no cover
            raise ValueError("No audio data available")
        return AudioContent(type="audio", data=data, mimeType=self._mime_type)

View File

@@ -0,0 +1,3 @@
from .server import NotificationOptions, Server
__all__ = ["Server", "NotificationOptions"]

Some files were not shown because too many files have changed in this diff Show More