Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for active project before fetching all sessions - When active project exists, use project.sessions instead of fetching from API - Added detailed console logging to debug session filtering - This prevents ALL sessions from appearing in every project's sidebar Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
135
.venv/lib/python3.11/site-packages/mcp/__init__.py
Normal file
135
.venv/lib/python3.11/site-packages/mcp/__init__.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from .client.session import ClientSession
|
||||
from .client.session_group import ClientSessionGroup
|
||||
from .client.stdio import StdioServerParameters, stdio_client
|
||||
from .server.session import ServerSession
|
||||
from .server.stdio import stdio_server
|
||||
from .shared.exceptions import McpError, UrlElicitationRequiredError
|
||||
from .types import (
|
||||
CallToolRequest,
|
||||
ClientCapabilities,
|
||||
ClientNotification,
|
||||
ClientRequest,
|
||||
ClientResult,
|
||||
CompleteRequest,
|
||||
CreateMessageRequest,
|
||||
CreateMessageResult,
|
||||
CreateMessageResultWithTools,
|
||||
ErrorData,
|
||||
GetPromptRequest,
|
||||
GetPromptResult,
|
||||
Implementation,
|
||||
IncludeContext,
|
||||
InitializedNotification,
|
||||
InitializeRequest,
|
||||
InitializeResult,
|
||||
JSONRPCError,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
ListPromptsRequest,
|
||||
ListPromptsResult,
|
||||
ListResourcesRequest,
|
||||
ListResourcesResult,
|
||||
ListToolsResult,
|
||||
LoggingLevel,
|
||||
LoggingMessageNotification,
|
||||
Notification,
|
||||
PingRequest,
|
||||
ProgressNotification,
|
||||
PromptsCapability,
|
||||
ReadResourceRequest,
|
||||
ReadResourceResult,
|
||||
Resource,
|
||||
ResourcesCapability,
|
||||
ResourceUpdatedNotification,
|
||||
RootsCapability,
|
||||
SamplingCapability,
|
||||
SamplingContent,
|
||||
SamplingContextCapability,
|
||||
SamplingMessage,
|
||||
SamplingMessageContentBlock,
|
||||
SamplingToolsCapability,
|
||||
ServerCapabilities,
|
||||
ServerNotification,
|
||||
ServerRequest,
|
||||
ServerResult,
|
||||
SetLevelRequest,
|
||||
StopReason,
|
||||
SubscribeRequest,
|
||||
Tool,
|
||||
ToolChoice,
|
||||
ToolResultContent,
|
||||
ToolsCapability,
|
||||
ToolUseContent,
|
||||
UnsubscribeRequest,
|
||||
)
|
||||
from .types import (
|
||||
Role as SamplingRole,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CallToolRequest",
|
||||
"ClientCapabilities",
|
||||
"ClientNotification",
|
||||
"ClientRequest",
|
||||
"ClientResult",
|
||||
"ClientSession",
|
||||
"ClientSessionGroup",
|
||||
"CompleteRequest",
|
||||
"CreateMessageRequest",
|
||||
"CreateMessageResult",
|
||||
"CreateMessageResultWithTools",
|
||||
"ErrorData",
|
||||
"GetPromptRequest",
|
||||
"GetPromptResult",
|
||||
"Implementation",
|
||||
"IncludeContext",
|
||||
"InitializeRequest",
|
||||
"InitializeResult",
|
||||
"InitializedNotification",
|
||||
"JSONRPCError",
|
||||
"JSONRPCRequest",
|
||||
"JSONRPCResponse",
|
||||
"ListPromptsRequest",
|
||||
"ListPromptsResult",
|
||||
"ListResourcesRequest",
|
||||
"ListResourcesResult",
|
||||
"ListToolsResult",
|
||||
"LoggingLevel",
|
||||
"LoggingMessageNotification",
|
||||
"McpError",
|
||||
"Notification",
|
||||
"PingRequest",
|
||||
"ProgressNotification",
|
||||
"PromptsCapability",
|
||||
"ReadResourceRequest",
|
||||
"ReadResourceResult",
|
||||
"Resource",
|
||||
"ResourcesCapability",
|
||||
"ResourceUpdatedNotification",
|
||||
"RootsCapability",
|
||||
"SamplingCapability",
|
||||
"SamplingContent",
|
||||
"SamplingContextCapability",
|
||||
"SamplingMessage",
|
||||
"SamplingMessageContentBlock",
|
||||
"SamplingRole",
|
||||
"SamplingToolsCapability",
|
||||
"ServerCapabilities",
|
||||
"ServerNotification",
|
||||
"ServerRequest",
|
||||
"ServerResult",
|
||||
"ServerSession",
|
||||
"SetLevelRequest",
|
||||
"StdioServerParameters",
|
||||
"StopReason",
|
||||
"SubscribeRequest",
|
||||
"Tool",
|
||||
"ToolChoice",
|
||||
"ToolResultContent",
|
||||
"ToolsCapability",
|
||||
"ToolUseContent",
|
||||
"UnsubscribeRequest",
|
||||
"UrlElicitationRequiredError",
|
||||
"stdio_client",
|
||||
"stdio_server",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
6
.venv/lib/python3.11/site-packages/mcp/cli/__init__.py
Normal file
6
.venv/lib/python3.11/site-packages/mcp/cli/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""FastMCP CLI package."""
|
||||
|
||||
from .cli import app
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
app()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
148
.venv/lib/python3.11/site-packages/mcp/cli/claude.py
Normal file
148
.venv/lib/python3.11/site-packages/mcp/cli/claude.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Claude app integration utilities."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
MCP_PACKAGE = "mcp[cli]"
|
||||
|
||||
|
||||
def get_claude_config_path() -> Path | None: # pragma: no cover
|
||||
"""Get the Claude config directory based on platform."""
|
||||
if sys.platform == "win32":
|
||||
path = Path(Path.home(), "AppData", "Roaming", "Claude")
|
||||
elif sys.platform == "darwin":
|
||||
path = Path(Path.home(), "Library", "Application Support", "Claude")
|
||||
elif sys.platform.startswith("linux"):
|
||||
path = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"), "Claude")
|
||||
else:
|
||||
return None
|
||||
|
||||
if path.exists():
|
||||
return path
|
||||
return None
|
||||
|
||||
|
||||
def get_uv_path() -> str:
|
||||
"""Get the full path to the uv executable."""
|
||||
uv_path = shutil.which("uv")
|
||||
if not uv_path: # pragma: no cover
|
||||
logger.error(
|
||||
"uv executable not found in PATH, falling back to 'uv'. Please ensure uv is installed and in your PATH"
|
||||
)
|
||||
return "uv" # Fall back to just "uv" if not found
|
||||
return uv_path
|
||||
|
||||
|
||||
def update_claude_config(
|
||||
file_spec: str,
|
||||
server_name: str,
|
||||
*,
|
||||
with_editable: Path | None = None,
|
||||
with_packages: list[str] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
) -> bool:
|
||||
"""Add or update a FastMCP server in Claude's configuration.
|
||||
|
||||
Args:
|
||||
file_spec: Path to the server file, optionally with :object suffix
|
||||
server_name: Name for the server in Claude's config
|
||||
with_editable: Optional directory to install in editable mode
|
||||
with_packages: Optional list of additional packages to install
|
||||
env_vars: Optional dictionary of environment variables. These are merged with
|
||||
any existing variables, with new values taking precedence.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If Claude Desktop's config directory is not found, indicating
|
||||
Claude Desktop may not be installed or properly set up.
|
||||
"""
|
||||
config_dir = get_claude_config_path()
|
||||
uv_path = get_uv_path()
|
||||
if not config_dir: # pragma: no cover
|
||||
raise RuntimeError(
|
||||
"Claude Desktop config directory not found. Please ensure Claude Desktop"
|
||||
" is installed and has been run at least once to initialize its config."
|
||||
)
|
||||
|
||||
config_file = config_dir / "claude_desktop_config.json"
|
||||
if not config_file.exists(): # pragma: no cover
|
||||
try:
|
||||
config_file.write_text("{}")
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to create Claude config file",
|
||||
extra={
|
||||
"config_file": str(config_file),
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
config = json.loads(config_file.read_text())
|
||||
if "mcpServers" not in config:
|
||||
config["mcpServers"] = {}
|
||||
|
||||
# Always preserve existing env vars and merge with new ones
|
||||
if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: # pragma: no cover
|
||||
existing_env = config["mcpServers"][server_name]["env"]
|
||||
if env_vars:
|
||||
# New vars take precedence over existing ones
|
||||
env_vars = {**existing_env, **env_vars}
|
||||
else:
|
||||
env_vars = existing_env
|
||||
|
||||
# Build uv run command
|
||||
args = ["run", "--frozen"]
|
||||
|
||||
# Collect all packages in a set to deduplicate
|
||||
packages = {MCP_PACKAGE}
|
||||
if with_packages: # pragma: no cover
|
||||
packages.update(pkg for pkg in with_packages if pkg)
|
||||
|
||||
# Add all packages with --with
|
||||
for pkg in sorted(packages):
|
||||
args.extend(["--with", pkg])
|
||||
|
||||
if with_editable: # pragma: no cover
|
||||
args.extend(["--with-editable", str(with_editable)])
|
||||
|
||||
# Convert file path to absolute before adding to command
|
||||
# Split off any :object suffix first
|
||||
if ":" in file_spec:
|
||||
file_path, server_object = file_spec.rsplit(":", 1)
|
||||
file_spec = f"{Path(file_path).resolve()}:{server_object}"
|
||||
else: # pragma: no cover
|
||||
file_spec = str(Path(file_spec).resolve())
|
||||
|
||||
# Add fastmcp run command
|
||||
args.extend(["mcp", "run", file_spec])
|
||||
|
||||
server_config: dict[str, Any] = {"command": uv_path, "args": args}
|
||||
|
||||
# Add environment variables if specified
|
||||
if env_vars: # pragma: no cover
|
||||
server_config["env"] = env_vars
|
||||
|
||||
config["mcpServers"][server_name] = server_config
|
||||
|
||||
config_file.write_text(json.dumps(config, indent=2))
|
||||
logger.info(
|
||||
f"Added server '{server_name}' to Claude config",
|
||||
extra={"config_file": str(config_file)},
|
||||
)
|
||||
return True
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception(
|
||||
"Failed to update Claude config",
|
||||
extra={
|
||||
"config_file": str(config_file),
|
||||
},
|
||||
)
|
||||
return False
|
||||
488
.venv/lib/python3.11/site-packages/mcp/cli/cli.py
Normal file
488
.venv/lib/python3.11/site-packages/mcp/cli/cli.py
Normal file
@@ -0,0 +1,488 @@
|
||||
"""MCP CLI tools."""
|
||||
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
|
||||
from mcp.server import FastMCP
|
||||
from mcp.server import Server as LowLevelServer
|
||||
|
||||
try:
|
||||
import typer
|
||||
except ImportError: # pragma: no cover
|
||||
print("Error: typer is required. Install with 'pip install mcp[cli]'")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
from mcp.cli import claude
|
||||
from mcp.server.fastmcp.utilities.logging import get_logger
|
||||
except ImportError: # pragma: no cover
|
||||
print("Error: mcp.server.fastmcp is not installed or not in PYTHONPATH")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import dotenv
|
||||
except ImportError: # pragma: no cover
|
||||
dotenv = None
|
||||
|
||||
logger = get_logger("cli")
|
||||
|
||||
app = typer.Typer(
|
||||
name="mcp",
|
||||
help="MCP development tools",
|
||||
add_completion=False,
|
||||
no_args_is_help=True, # Show help if no args provided
|
||||
)
|
||||
|
||||
|
||||
def _get_npx_command():
|
||||
"""Get the correct npx command for the current platform."""
|
||||
if sys.platform == "win32":
|
||||
# Try both npx.cmd and npx.exe on Windows
|
||||
for cmd in ["npx.cmd", "npx.exe", "npx"]:
|
||||
try:
|
||||
subprocess.run([cmd, "--version"], check=True, capture_output=True, shell=True)
|
||||
return cmd
|
||||
except subprocess.CalledProcessError:
|
||||
continue
|
||||
return None
|
||||
return "npx" # On Unix-like systems, just use npx
|
||||
|
||||
|
||||
def _parse_env_var(env_var: str) -> tuple[str, str]: # pragma: no cover
|
||||
"""Parse environment variable string in format KEY=VALUE."""
|
||||
if "=" not in env_var:
|
||||
logger.error(f"Invalid environment variable format: {env_var}. Must be KEY=VALUE")
|
||||
sys.exit(1)
|
||||
key, value = env_var.split("=", 1)
|
||||
return key.strip(), value.strip()
|
||||
|
||||
|
||||
def _build_uv_command(
|
||||
file_spec: str,
|
||||
with_editable: Path | None = None,
|
||||
with_packages: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Build the uv run command that runs an MCP server through mcp run."""
|
||||
cmd = ["uv"]
|
||||
|
||||
cmd.extend(["run", "--with", "mcp"])
|
||||
|
||||
if with_editable:
|
||||
cmd.extend(["--with-editable", str(with_editable)])
|
||||
|
||||
if with_packages:
|
||||
for pkg in with_packages:
|
||||
if pkg: # pragma: no cover
|
||||
cmd.extend(["--with", pkg])
|
||||
|
||||
# Add mcp run command
|
||||
cmd.extend(["mcp", "run", file_spec])
|
||||
return cmd
|
||||
|
||||
|
||||
def _parse_file_path(file_spec: str) -> tuple[Path, str | None]:
|
||||
"""Parse a file path that may include a server object specification.
|
||||
|
||||
Args:
|
||||
file_spec: Path to file, optionally with :object suffix
|
||||
|
||||
Returns:
|
||||
Tuple of (file_path, server_object)
|
||||
"""
|
||||
# First check if we have a Windows path (e.g., C:\...)
|
||||
has_windows_drive = len(file_spec) > 1 and file_spec[1] == ":"
|
||||
|
||||
# Split on the last colon, but only if it's not part of the Windows drive letter
|
||||
# and there's actually another colon in the string after the drive letter
|
||||
if ":" in (file_spec[2:] if has_windows_drive else file_spec):
|
||||
file_str, server_object = file_spec.rsplit(":", 1)
|
||||
else:
|
||||
file_str, server_object = file_spec, None
|
||||
|
||||
# Resolve the file path
|
||||
file_path = Path(file_str).expanduser().resolve()
|
||||
if not file_path.exists():
|
||||
logger.error(f"File not found: {file_path}")
|
||||
sys.exit(1)
|
||||
if not file_path.is_file():
|
||||
logger.error(f"Not a file: {file_path}")
|
||||
sys.exit(1)
|
||||
|
||||
return file_path, server_object
|
||||
|
||||
|
||||
def _import_server(file: Path, server_object: str | None = None): # pragma: no cover
|
||||
"""Import an MCP server from a file.
|
||||
|
||||
Args:
|
||||
file: Path to the file
|
||||
server_object: Optional object name in format "module:object" or just "object"
|
||||
|
||||
Returns:
|
||||
The server object
|
||||
"""
|
||||
# Add parent directory to Python path so imports can be resolved
|
||||
file_dir = str(file.parent)
|
||||
if file_dir not in sys.path:
|
||||
sys.path.insert(0, file_dir)
|
||||
|
||||
# Import the module
|
||||
spec = importlib.util.spec_from_file_location("server_module", file)
|
||||
if not spec or not spec.loader:
|
||||
logger.error("Could not load module", extra={"file": str(file)})
|
||||
sys.exit(1)
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
def _check_server_object(server_object: Any, object_name: str):
|
||||
"""Helper function to check that the server object is supported
|
||||
|
||||
Args:
|
||||
server_object: The server object to check.
|
||||
|
||||
Returns:
|
||||
True if it's supported.
|
||||
"""
|
||||
if not isinstance(server_object, FastMCP):
|
||||
logger.error(f"The server object {object_name} is of type {type(server_object)} (expecting {FastMCP}).")
|
||||
if isinstance(server_object, LowLevelServer):
|
||||
logger.warning(
|
||||
"Note that only FastMCP server is supported. Low level Server class is not yet supported."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
# If no object specified, try common server names
|
||||
if not server_object:
|
||||
# Look for the most common server object names
|
||||
for name in ["mcp", "server", "app"]:
|
||||
if hasattr(module, name):
|
||||
if not _check_server_object(getattr(module, name), f"{file}:{name}"):
|
||||
logger.error(f"Ignoring object '{file}:{name}' as it's not a valid server object")
|
||||
continue
|
||||
return getattr(module, name)
|
||||
|
||||
logger.error(
|
||||
f"No server object found in {file}. Please either:\n"
|
||||
"1. Use a standard variable name (mcp, server, or app)\n"
|
||||
"2. Specify the object name with file:object syntax"
|
||||
"3. If the server creates the FastMCP object within main() "
|
||||
" or another function, refactor the FastMCP object to be a "
|
||||
" global variable named mcp, server, or app.",
|
||||
extra={"file": str(file)},
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Handle module:object syntax
|
||||
if ":" in server_object:
|
||||
module_name, object_name = server_object.split(":", 1)
|
||||
try:
|
||||
server_module = importlib.import_module(module_name)
|
||||
server = getattr(server_module, object_name, None)
|
||||
except ImportError:
|
||||
logger.error(
|
||||
f"Could not import module '{module_name}'",
|
||||
extra={"file": str(file)},
|
||||
)
|
||||
sys.exit(1)
|
||||
else:
|
||||
# Just object name
|
||||
server = getattr(module, server_object, None)
|
||||
|
||||
if server is None:
|
||||
logger.error(
|
||||
f"Server object '{server_object}' not found",
|
||||
extra={"file": str(file)},
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if not _check_server_object(server, server_object):
|
||||
sys.exit(1)
|
||||
|
||||
return server
|
||||
|
||||
|
||||
@app.command()
|
||||
def version() -> None: # pragma: no cover
|
||||
"""Show the MCP version."""
|
||||
try:
|
||||
version = importlib.metadata.version("mcp")
|
||||
print(f"MCP version {version}")
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
print("MCP version unknown (package not installed)")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@app.command()
|
||||
def dev(
|
||||
file_spec: str = typer.Argument(
|
||||
...,
|
||||
help="Python file to run, optionally with :object suffix",
|
||||
),
|
||||
with_editable: Annotated[
|
||||
Path | None,
|
||||
typer.Option(
|
||||
"--with-editable",
|
||||
"-e",
|
||||
help="Directory containing pyproject.toml to install in editable mode",
|
||||
exists=True,
|
||||
file_okay=False,
|
||||
resolve_path=True,
|
||||
),
|
||||
] = None,
|
||||
with_packages: Annotated[
|
||||
list[str],
|
||||
typer.Option(
|
||||
"--with",
|
||||
help="Additional packages to install",
|
||||
),
|
||||
] = [],
|
||||
) -> None: # pragma: no cover
|
||||
"""Run an MCP server with the MCP Inspector."""
|
||||
file, server_object = _parse_file_path(file_spec)
|
||||
|
||||
logger.debug(
|
||||
"Starting dev server",
|
||||
extra={
|
||||
"file": str(file),
|
||||
"server_object": server_object,
|
||||
"with_editable": str(with_editable) if with_editable else None,
|
||||
"with_packages": with_packages,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Import server to get dependencies
|
||||
server = _import_server(file, server_object)
|
||||
if hasattr(server, "dependencies"):
|
||||
with_packages = list(set(with_packages + server.dependencies))
|
||||
|
||||
uv_cmd = _build_uv_command(file_spec, with_editable, with_packages)
|
||||
|
||||
# Get the correct npx command
|
||||
npx_cmd = _get_npx_command()
|
||||
if not npx_cmd:
|
||||
logger.error(
|
||||
"npx not found. Please ensure Node.js and npm are properly installed and added to your system PATH."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Run the MCP Inspector command with shell=True on Windows
|
||||
shell = sys.platform == "win32"
|
||||
process = subprocess.run(
|
||||
[npx_cmd, "@modelcontextprotocol/inspector"] + uv_cmd,
|
||||
check=True,
|
||||
shell=shell,
|
||||
env=dict(os.environ.items()), # Convert to list of tuples for env update
|
||||
)
|
||||
sys.exit(process.returncode)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error(
|
||||
"Dev server failed",
|
||||
extra={
|
||||
"file": str(file),
|
||||
"error": str(e),
|
||||
"returncode": e.returncode,
|
||||
},
|
||||
)
|
||||
sys.exit(e.returncode)
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"npx not found. Please ensure Node.js and npm are properly installed "
|
||||
"and added to your system PATH. You may need to restart your terminal "
|
||||
"after installation.",
|
||||
extra={"file": str(file)},
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@app.command()
|
||||
def run(
|
||||
file_spec: str = typer.Argument(
|
||||
...,
|
||||
help="Python file to run, optionally with :object suffix",
|
||||
),
|
||||
transport: Annotated[
|
||||
str | None,
|
||||
typer.Option(
|
||||
"--transport",
|
||||
"-t",
|
||||
help="Transport protocol to use (stdio or sse)",
|
||||
),
|
||||
] = None,
|
||||
) -> None: # pragma: no cover
|
||||
"""Run an MCP server.
|
||||
|
||||
The server can be specified in two ways:\n
|
||||
1. Module approach: server.py - runs the module directly, expecting a server.run() call.\n
|
||||
2. Import approach: server.py:app - imports and runs the specified server object.\n\n
|
||||
|
||||
Note: This command runs the server directly. You are responsible for ensuring
|
||||
all dependencies are available.\n
|
||||
For dependency management, use `mcp install` or `mcp dev` instead.
|
||||
""" # noqa: E501
|
||||
file, server_object = _parse_file_path(file_spec)
|
||||
|
||||
logger.debug(
|
||||
"Running server",
|
||||
extra={
|
||||
"file": str(file),
|
||||
"server_object": server_object,
|
||||
"transport": transport,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Import and get server object
|
||||
server = _import_server(file, server_object)
|
||||
|
||||
# Run the server
|
||||
kwargs = {}
|
||||
if transport:
|
||||
kwargs["transport"] = transport
|
||||
|
||||
server.run(**kwargs)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to run server",
|
||||
extra={
|
||||
"file": str(file),
|
||||
},
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@app.command()
|
||||
def install(
|
||||
file_spec: str = typer.Argument(
|
||||
...,
|
||||
help="Python file to run, optionally with :object suffix",
|
||||
),
|
||||
server_name: Annotated[
|
||||
str | None,
|
||||
typer.Option(
|
||||
"--name",
|
||||
"-n",
|
||||
help="Custom name for the server (defaults to server's name attribute or file name)",
|
||||
),
|
||||
] = None,
|
||||
with_editable: Annotated[
|
||||
Path | None,
|
||||
typer.Option(
|
||||
"--with-editable",
|
||||
"-e",
|
||||
help="Directory containing pyproject.toml to install in editable mode",
|
||||
exists=True,
|
||||
file_okay=False,
|
||||
resolve_path=True,
|
||||
),
|
||||
] = None,
|
||||
with_packages: Annotated[
|
||||
list[str],
|
||||
typer.Option(
|
||||
"--with",
|
||||
help="Additional packages to install",
|
||||
),
|
||||
] = [],
|
||||
env_vars: Annotated[
|
||||
list[str],
|
||||
typer.Option(
|
||||
"--env-var",
|
||||
"-v",
|
||||
help="Environment variables in KEY=VALUE format",
|
||||
),
|
||||
] = [],
|
||||
env_file: Annotated[
|
||||
Path | None,
|
||||
typer.Option(
|
||||
"--env-file",
|
||||
"-f",
|
||||
help="Load environment variables from a .env file",
|
||||
exists=True,
|
||||
file_okay=True,
|
||||
dir_okay=False,
|
||||
resolve_path=True,
|
||||
),
|
||||
] = None,
|
||||
) -> None: # pragma: no cover
|
||||
"""Install an MCP server in the Claude desktop app.
|
||||
|
||||
Environment variables are preserved once added and only updated if new values
|
||||
are explicitly provided.
|
||||
"""
|
||||
file, server_object = _parse_file_path(file_spec)
|
||||
|
||||
logger.debug(
|
||||
"Installing server",
|
||||
extra={
|
||||
"file": str(file),
|
||||
"server_name": server_name,
|
||||
"server_object": server_object,
|
||||
"with_editable": str(with_editable) if with_editable else None,
|
||||
"with_packages": with_packages,
|
||||
},
|
||||
)
|
||||
|
||||
if not claude.get_claude_config_path():
|
||||
logger.error("Claude app not found")
|
||||
sys.exit(1)
|
||||
|
||||
# Try to import server to get its name, but fall back to file name if dependencies
|
||||
# missing
|
||||
name = server_name
|
||||
server = None
|
||||
if not name:
|
||||
try:
|
||||
server = _import_server(file, server_object)
|
||||
name = server.name
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.debug(
|
||||
"Could not import server (likely missing dependencies), using file name",
|
||||
extra={"error": str(e)},
|
||||
)
|
||||
name = file.stem
|
||||
|
||||
# Get server dependencies if available
|
||||
server_dependencies = getattr(server, "dependencies", []) if server else []
|
||||
if server_dependencies:
|
||||
with_packages = list(set(with_packages + server_dependencies))
|
||||
|
||||
# Process environment variables if provided
|
||||
env_dict: dict[str, str] | None = None
|
||||
if env_file or env_vars:
|
||||
env_dict = {}
|
||||
# Load from .env file if specified
|
||||
if env_file:
|
||||
if dotenv:
|
||||
try:
|
||||
env_dict |= {k: v for k, v in dotenv.dotenv_values(env_file).items() if v is not None}
|
||||
except (OSError, ValueError):
|
||||
logger.exception("Failed to load .env file")
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.error("python-dotenv is not installed. Cannot load .env file.")
|
||||
sys.exit(1)
|
||||
|
||||
# Add command line environment variables
|
||||
for env_var in env_vars:
|
||||
key, value = _parse_env_var(env_var)
|
||||
env_dict[key] = value
|
||||
|
||||
if claude.update_claude_config(
|
||||
file_spec,
|
||||
name,
|
||||
with_editable=with_editable,
|
||||
with_packages=with_packages,
|
||||
env_vars=env_dict,
|
||||
):
|
||||
logger.info(f"Successfully installed {name} in Claude app")
|
||||
else:
|
||||
logger.error(f"Failed to install {name} in Claude app")
|
||||
sys.exit(1)
|
||||
85
.venv/lib/python3.11/site-packages/mcp/client/__main__.py
Normal file
85
.venv/lib/python3.11/site-packages/mcp/client/__main__.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.session import RequestResponder
|
||||
|
||||
if not sys.warnoptions:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("client")
|
||||
|
||||
|
||||
async def message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
if isinstance(message, Exception):
|
||||
logger.error("Error: %s", message)
|
||||
return
|
||||
|
||||
logger.info("Received message from server: %s", message)
|
||||
|
||||
|
||||
async def run_session(
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||
client_info: types.Implementation | None = None,
|
||||
):
|
||||
async with ClientSession(
|
||||
read_stream,
|
||||
write_stream,
|
||||
message_handler=message_handler,
|
||||
client_info=client_info,
|
||||
) as session:
|
||||
logger.info("Initializing session")
|
||||
await session.initialize()
|
||||
logger.info("Initialized")
|
||||
|
||||
|
||||
async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]):
|
||||
env_dict = dict(env)
|
||||
|
||||
if urlparse(command_or_url).scheme in ("http", "https"):
|
||||
# Use SSE client for HTTP(S) URLs
|
||||
async with sse_client(command_or_url) as streams:
|
||||
await run_session(*streams)
|
||||
else:
|
||||
# Use stdio client for commands
|
||||
server_parameters = StdioServerParameters(command=command_or_url, args=args, env=env_dict)
|
||||
async with stdio_client(server_parameters) as streams:
|
||||
await run_session(*streams)
|
||||
|
||||
|
||||
def cli():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("command_or_url", help="Command or URL to connect to")
|
||||
parser.add_argument("args", nargs="*", help="Additional arguments")
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--env",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("KEY", "VALUE"),
|
||||
help="Environment variables to set. Can be used multiple times.",
|
||||
default=[],
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
anyio.run(partial(main, args.command_or_url, args.args, args.env), backend="trio")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
OAuth2 Authentication implementation for HTTPX.
|
||||
|
||||
Implements authorization code flow with PKCE and automatic token refresh.
|
||||
"""
|
||||
|
||||
from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
|
||||
from mcp.client.auth.oauth2 import (
|
||||
OAuthClientProvider,
|
||||
PKCEParameters,
|
||||
TokenStorage,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OAuthClientProvider",
|
||||
"OAuthFlowError",
|
||||
"OAuthRegistrationError",
|
||||
"OAuthTokenError",
|
||||
"PKCEParameters",
|
||||
"TokenStorage",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,10 @@
|
||||
class OAuthFlowError(Exception):
|
||||
"""Base exception for OAuth flow errors."""
|
||||
|
||||
|
||||
class OAuthTokenError(OAuthFlowError):
|
||||
"""Raised when token operations fail."""
|
||||
|
||||
|
||||
class OAuthRegistrationError(OAuthFlowError):
|
||||
"""Raised when client registration fails."""
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
OAuth client credential extensions for MCP.
|
||||
|
||||
Provides OAuth providers for machine-to-machine authentication flows:
|
||||
- ClientCredentialsOAuthProvider: For client_credentials with client_id + client_secret
|
||||
- PrivateKeyJWTOAuthProvider: For client_credentials with private_key_jwt authentication
|
||||
(typically using a pre-built JWT from workload identity federation)
|
||||
- RFC7523OAuthClientProvider: For jwt-bearer grant (RFC 7523 Section 2.1)
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
|
||||
|
||||
|
||||
class ClientCredentialsOAuthProvider(OAuthClientProvider):
|
||||
"""OAuth provider for client_credentials grant with client_id + client_secret.
|
||||
|
||||
This provider sets client_info directly, bypassing dynamic client registration.
|
||||
Use this when you already have client credentials (client_id and client_secret).
|
||||
|
||||
Example:
|
||||
```python
|
||||
provider = ClientCredentialsOAuthProvider(
|
||||
server_url="https://api.example.com",
|
||||
storage=my_token_storage,
|
||||
client_id="my-client-id",
|
||||
client_secret="my-client-secret",
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
storage: TokenStorage,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
token_endpoint_auth_method: Literal["client_secret_basic", "client_secret_post"] = "client_secret_basic",
|
||||
scopes: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize client_credentials OAuth provider.
|
||||
|
||||
Args:
|
||||
server_url: The MCP server URL.
|
||||
storage: Token storage implementation.
|
||||
client_id: The OAuth client ID.
|
||||
client_secret: The OAuth client secret.
|
||||
token_endpoint_auth_method: Authentication method for token endpoint.
|
||||
Either "client_secret_basic" (default) or "client_secret_post".
|
||||
scopes: Optional space-separated list of scopes to request.
|
||||
"""
|
||||
# Build minimal client_metadata for the base class
|
||||
client_metadata = OAuthClientMetadata(
|
||||
redirect_uris=None,
|
||||
grant_types=["client_credentials"],
|
||||
token_endpoint_auth_method=token_endpoint_auth_method,
|
||||
scope=scopes,
|
||||
)
|
||||
super().__init__(server_url, client_metadata, storage, None, None, 300.0)
|
||||
# Store client_info to be set during _initialize - no dynamic registration needed
|
||||
self._fixed_client_info = OAuthClientInformationFull(
|
||||
redirect_uris=None,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
grant_types=["client_credentials"],
|
||||
token_endpoint_auth_method=token_endpoint_auth_method,
|
||||
scope=scopes,
|
||||
)
|
||||
|
||||
async def _initialize(self) -> None:
|
||||
"""Load stored tokens and set pre-configured client_info."""
|
||||
self.context.current_tokens = await self.context.storage.get_tokens()
|
||||
self.context.client_info = self._fixed_client_info
|
||||
self._initialized = True
|
||||
|
||||
async def _perform_authorization(self) -> httpx.Request:
|
||||
"""Perform client_credentials authorization."""
|
||||
return await self._exchange_token_client_credentials()
|
||||
|
||||
async def _exchange_token_client_credentials(self) -> httpx.Request:
|
||||
"""Build token exchange request for client_credentials grant."""
|
||||
token_data: dict[str, Any] = {
|
||||
"grant_type": "client_credentials",
|
||||
}
|
||||
|
||||
headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
# Use standard auth methods (client_secret_basic, client_secret_post, none)
|
||||
token_data, headers = self.context.prepare_token_auth(token_data, headers)
|
||||
|
||||
if self.context.should_include_resource_param(self.context.protocol_version):
|
||||
token_data["resource"] = self.context.get_resource_url()
|
||||
|
||||
if self.context.client_metadata.scope:
|
||||
token_data["scope"] = self.context.client_metadata.scope
|
||||
|
||||
token_url = self._get_token_endpoint()
|
||||
return httpx.Request("POST", token_url, data=token_data, headers=headers)
|
||||
|
||||
|
||||
def static_assertion_provider(token: str) -> Callable[[str], Awaitable[str]]:
|
||||
"""Create an assertion provider that returns a static JWT token.
|
||||
|
||||
Use this when you have a pre-built JWT (e.g., from workload identity federation)
|
||||
that doesn't need the audience parameter.
|
||||
|
||||
Example:
|
||||
```python
|
||||
provider = PrivateKeyJWTOAuthProvider(
|
||||
server_url="https://api.example.com",
|
||||
storage=my_token_storage,
|
||||
client_id="my-client-id",
|
||||
assertion_provider=static_assertion_provider(my_prebuilt_jwt),
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
token: The pre-built JWT assertion string.
|
||||
|
||||
Returns:
|
||||
An async callback suitable for use as an assertion_provider.
|
||||
"""
|
||||
|
||||
async def provider(audience: str) -> str:
|
||||
return token
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
class SignedJWTParameters(BaseModel):
|
||||
"""Parameters for creating SDK-signed JWT assertions.
|
||||
|
||||
Use `create_assertion_provider()` to create an assertion provider callback
|
||||
for use with `PrivateKeyJWTOAuthProvider`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
jwt_params = SignedJWTParameters(
|
||||
issuer="my-client-id",
|
||||
subject="my-client-id",
|
||||
signing_key=private_key_pem,
|
||||
)
|
||||
provider = PrivateKeyJWTOAuthProvider(
|
||||
server_url="https://api.example.com",
|
||||
storage=my_token_storage,
|
||||
client_id="my-client-id",
|
||||
assertion_provider=jwt_params.create_assertion_provider(),
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
issuer: str = Field(description="Issuer for JWT assertions (typically client_id).")
|
||||
subject: str = Field(description="Subject identifier for JWT assertions (typically client_id).")
|
||||
signing_key: str = Field(description="Private key for JWT signing (PEM format).")
|
||||
signing_algorithm: str = Field(default="RS256", description="Algorithm for signing JWT assertions.")
|
||||
lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
|
||||
additional_claims: dict[str, Any] | None = Field(default=None, description="Additional claims.")
|
||||
|
||||
def create_assertion_provider(self) -> Callable[[str], Awaitable[str]]:
|
||||
"""Create an assertion provider callback for use with PrivateKeyJWTOAuthProvider.
|
||||
|
||||
Returns:
|
||||
An async callback that takes the audience (authorization server issuer URL)
|
||||
and returns a signed JWT assertion.
|
||||
"""
|
||||
|
||||
async def provider(audience: str) -> str:
|
||||
now = int(time.time())
|
||||
claims: dict[str, Any] = {
|
||||
"iss": self.issuer,
|
||||
"sub": self.subject,
|
||||
"aud": audience,
|
||||
"exp": now + self.lifetime_seconds,
|
||||
"iat": now,
|
||||
"jti": str(uuid4()),
|
||||
}
|
||||
if self.additional_claims:
|
||||
claims.update(self.additional_claims)
|
||||
|
||||
return jwt.encode(claims, self.signing_key, algorithm=self.signing_algorithm)
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
class PrivateKeyJWTOAuthProvider(OAuthClientProvider):
|
||||
"""OAuth provider for client_credentials grant with private_key_jwt authentication.
|
||||
|
||||
Uses RFC 7523 Section 2.2 for client authentication via JWT assertion.
|
||||
|
||||
The JWT assertion's audience MUST be the authorization server's issuer identifier
|
||||
(per RFC 7523bis security updates). The `assertion_provider` callback receives
|
||||
this audience value and must return a JWT with that audience.
|
||||
|
||||
**Option 1: Pre-built JWT via Workload Identity Federation**
|
||||
|
||||
In production scenarios, the JWT assertion is typically obtained from a workload
|
||||
identity provider (e.g., GCP, AWS IAM, Azure AD):
|
||||
|
||||
```python
|
||||
async def get_workload_identity_token(audience: str) -> str:
|
||||
# Fetch JWT from your identity provider
|
||||
# The JWT's audience must match the provided audience parameter
|
||||
return await fetch_token_from_identity_provider(audience=audience)
|
||||
|
||||
provider = PrivateKeyJWTOAuthProvider(
|
||||
server_url="https://api.example.com",
|
||||
storage=my_token_storage,
|
||||
client_id="my-client-id",
|
||||
assertion_provider=get_workload_identity_token,
|
||||
)
|
||||
```
|
||||
|
||||
**Option 2: Static pre-built JWT**
|
||||
|
||||
If you have a static JWT that doesn't need the audience parameter:
|
||||
|
||||
```python
|
||||
provider = PrivateKeyJWTOAuthProvider(
|
||||
server_url="https://api.example.com",
|
||||
storage=my_token_storage,
|
||||
client_id="my-client-id",
|
||||
assertion_provider=static_assertion_provider(my_prebuilt_jwt),
|
||||
)
|
||||
```
|
||||
|
||||
**Option 3: SDK-signed JWT (for testing/simple setups)**
|
||||
|
||||
For testing or simple deployments, use `SignedJWTParameters.create_assertion_provider()`:
|
||||
|
||||
```python
|
||||
jwt_params = SignedJWTParameters(
|
||||
issuer="my-client-id",
|
||||
subject="my-client-id",
|
||||
signing_key=private_key_pem,
|
||||
)
|
||||
provider = PrivateKeyJWTOAuthProvider(
|
||||
server_url="https://api.example.com",
|
||||
storage=my_token_storage,
|
||||
client_id="my-client-id",
|
||||
assertion_provider=jwt_params.create_assertion_provider(),
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
storage: TokenStorage,
|
||||
client_id: str,
|
||||
assertion_provider: Callable[[str], Awaitable[str]],
|
||||
scopes: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize private_key_jwt OAuth provider.
|
||||
|
||||
Args:
|
||||
server_url: The MCP server URL.
|
||||
storage: Token storage implementation.
|
||||
client_id: The OAuth client ID.
|
||||
assertion_provider: Async callback that takes the audience (authorization
|
||||
server's issuer identifier) and returns a JWT assertion. Use
|
||||
`SignedJWTParameters.create_assertion_provider()` for SDK-signed JWTs,
|
||||
`static_assertion_provider()` for pre-built JWTs, or provide your own
|
||||
callback for workload identity federation.
|
||||
scopes: Optional space-separated list of scopes to request.
|
||||
"""
|
||||
# Build minimal client_metadata for the base class
|
||||
client_metadata = OAuthClientMetadata(
|
||||
redirect_uris=None,
|
||||
grant_types=["client_credentials"],
|
||||
token_endpoint_auth_method="private_key_jwt",
|
||||
scope=scopes,
|
||||
)
|
||||
super().__init__(server_url, client_metadata, storage, None, None, 300.0)
|
||||
self._assertion_provider = assertion_provider
|
||||
# Store client_info to be set during _initialize - no dynamic registration needed
|
||||
self._fixed_client_info = OAuthClientInformationFull(
|
||||
redirect_uris=None,
|
||||
client_id=client_id,
|
||||
grant_types=["client_credentials"],
|
||||
token_endpoint_auth_method="private_key_jwt",
|
||||
scope=scopes,
|
||||
)
|
||||
|
||||
async def _initialize(self) -> None:
|
||||
"""Load stored tokens and set pre-configured client_info."""
|
||||
self.context.current_tokens = await self.context.storage.get_tokens()
|
||||
self.context.client_info = self._fixed_client_info
|
||||
self._initialized = True
|
||||
|
||||
async def _perform_authorization(self) -> httpx.Request:
|
||||
"""Perform client_credentials authorization with private_key_jwt."""
|
||||
return await self._exchange_token_client_credentials()
|
||||
|
||||
async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> None:
|
||||
"""Add JWT assertion for client authentication to token endpoint parameters."""
|
||||
if not self.context.oauth_metadata:
|
||||
raise OAuthFlowError("Missing OAuth metadata for private_key_jwt flow") # pragma: no cover
|
||||
|
||||
# Audience MUST be the issuer identifier of the authorization server
|
||||
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01
|
||||
audience = str(self.context.oauth_metadata.issuer)
|
||||
assertion = await self._assertion_provider(audience)
|
||||
|
||||
# RFC 7523 Section 2.2: client authentication via JWT
|
||||
token_data["client_assertion"] = assertion
|
||||
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
|
||||
|
||||
async def _exchange_token_client_credentials(self) -> httpx.Request:
|
||||
"""Build token exchange request for client_credentials grant with private_key_jwt."""
|
||||
token_data: dict[str, Any] = {
|
||||
"grant_type": "client_credentials",
|
||||
}
|
||||
|
||||
headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
# Add JWT client authentication (RFC 7523 Section 2.2)
|
||||
await self._add_client_authentication_jwt(token_data=token_data)
|
||||
|
||||
if self.context.should_include_resource_param(self.context.protocol_version):
|
||||
token_data["resource"] = self.context.get_resource_url()
|
||||
|
||||
if self.context.client_metadata.scope:
|
||||
token_data["scope"] = self.context.client_metadata.scope
|
||||
|
||||
token_url = self._get_token_endpoint()
|
||||
return httpx.Request("POST", token_url, data=token_data, headers=headers)
|
||||
|
||||
|
||||
class JWTParameters(BaseModel):
|
||||
"""JWT parameters."""
|
||||
|
||||
assertion: str | None = Field(
|
||||
default=None,
|
||||
description="JWT assertion for JWT authentication. "
|
||||
"Will be used instead of generating a new assertion if provided.",
|
||||
)
|
||||
|
||||
issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
|
||||
subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
|
||||
audience: str | None = Field(default=None, description="Audience for JWT assertions.")
|
||||
claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
|
||||
jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
|
||||
jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
|
||||
jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
|
||||
|
||||
def to_assertion(self, with_audience_fallback: str | None = None) -> str:
|
||||
if self.assertion is not None:
|
||||
# Prebuilt JWT (e.g. acquired out-of-band)
|
||||
assertion = self.assertion
|
||||
else:
|
||||
if not self.jwt_signing_key:
|
||||
raise OAuthFlowError("Missing signing key for JWT bearer grant") # pragma: no cover
|
||||
if not self.issuer:
|
||||
raise OAuthFlowError("Missing issuer for JWT bearer grant") # pragma: no cover
|
||||
if not self.subject:
|
||||
raise OAuthFlowError("Missing subject for JWT bearer grant") # pragma: no cover
|
||||
|
||||
audience = self.audience if self.audience else with_audience_fallback
|
||||
if not audience:
|
||||
raise OAuthFlowError("Missing audience for JWT bearer grant") # pragma: no cover
|
||||
|
||||
now = int(time.time())
|
||||
claims: dict[str, Any] = {
|
||||
"iss": self.issuer,
|
||||
"sub": self.subject,
|
||||
"aud": audience,
|
||||
"exp": now + self.jwt_lifetime_seconds,
|
||||
"iat": now,
|
||||
"jti": str(uuid4()),
|
||||
}
|
||||
claims.update(self.claims or {})
|
||||
|
||||
assertion = jwt.encode(
|
||||
claims,
|
||||
self.jwt_signing_key,
|
||||
algorithm=self.jwt_signing_algorithm or "RS256",
|
||||
)
|
||||
return assertion
|
||||
|
||||
|
||||
class RFC7523OAuthClientProvider(OAuthClientProvider):
|
||||
"""OAuth client provider for RFC 7523 jwt-bearer grant.
|
||||
|
||||
.. deprecated::
|
||||
Use :class:`ClientCredentialsOAuthProvider` for client_credentials with
|
||||
client_id + client_secret, or :class:`PrivateKeyJWTOAuthProvider` for
|
||||
client_credentials with private_key_jwt authentication instead.
|
||||
|
||||
This provider supports the jwt-bearer authorization grant (RFC 7523 Section 2.1)
|
||||
where the JWT itself is the authorization grant.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
client_metadata: OAuthClientMetadata,
|
||||
storage: TokenStorage,
|
||||
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
|
||||
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
|
||||
timeout: float = 300.0,
|
||||
jwt_parameters: JWTParameters | None = None,
|
||||
) -> None:
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"RFC7523OAuthClientProvider is deprecated. Use ClientCredentialsOAuthProvider "
|
||||
"or PrivateKeyJWTOAuthProvider instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
|
||||
self.jwt_parameters = jwt_parameters
|
||||
|
||||
async def _exchange_token_authorization_code(
|
||||
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
|
||||
) -> httpx.Request: # pragma: no cover
|
||||
"""Build token exchange request for authorization_code flow."""
|
||||
token_data = token_data or {}
|
||||
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
|
||||
self._add_client_authentication_jwt(token_data=token_data)
|
||||
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)
|
||||
|
||||
async def _perform_authorization(self) -> httpx.Request: # pragma: no cover
|
||||
"""Perform the authorization flow."""
|
||||
if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
|
||||
token_request = await self._exchange_token_jwt_bearer()
|
||||
return token_request
|
||||
else:
|
||||
return await super()._perform_authorization()
|
||||
|
||||
def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # pragma: no cover
|
||||
"""Add JWT assertion for client authentication to token endpoint parameters."""
|
||||
if not self.jwt_parameters:
|
||||
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
|
||||
if not self.context.oauth_metadata:
|
||||
raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow")
|
||||
|
||||
# We need to set the audience to the issuer identifier of the authorization server
|
||||
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
|
||||
issuer = str(self.context.oauth_metadata.issuer)
|
||||
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
|
||||
|
||||
# When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
|
||||
token_data["client_assertion"] = assertion
|
||||
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
|
||||
# We need to set the audience to the resource server, the audience is difference from the one in claims
|
||||
# it represents the resource server that will validate the token
|
||||
token_data["audience"] = self.context.get_resource_url()
|
||||
|
||||
async def _exchange_token_jwt_bearer(self) -> httpx.Request:
|
||||
"""Build token exchange request for JWT bearer grant."""
|
||||
if not self.context.client_info:
|
||||
raise OAuthFlowError("Missing client info") # pragma: no cover
|
||||
if not self.jwt_parameters:
|
||||
raise OAuthFlowError("Missing JWT parameters") # pragma: no cover
|
||||
if not self.context.oauth_metadata:
|
||||
raise OAuthTokenError("Missing OAuth metadata") # pragma: no cover
|
||||
|
||||
# We need to set the audience to the issuer identifier of the authorization server
|
||||
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
|
||||
issuer = str(self.context.oauth_metadata.issuer)
|
||||
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
|
||||
|
||||
token_data = {
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
||||
"assertion": assertion,
|
||||
}
|
||||
|
||||
if self.context.should_include_resource_param(self.context.protocol_version): # pragma: no branch
|
||||
token_data["resource"] = self.context.get_resource_url()
|
||||
|
||||
if self.context.client_metadata.scope: # pragma: no branch
|
||||
token_data["scope"] = self.context.client_metadata.scope
|
||||
|
||||
token_url = self._get_token_endpoint()
|
||||
return httpx.Request(
|
||||
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
|
||||
)
|
||||
616
.venv/lib/python3.11/site-packages/mcp/client/auth/oauth2.py
Normal file
616
.venv/lib/python3.11/site-packages/mcp/client/auth/oauth2.py
Normal file
@@ -0,0 +1,616 @@
|
||||
"""
|
||||
OAuth2 Authentication implementation for HTTPX.
|
||||
|
||||
Implements authorization code flow with PKCE and automatic token refresh.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
from urllib.parse import quote, urlencode, urljoin, urlparse
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from mcp.client.auth.exceptions import OAuthFlowError, OAuthTokenError
|
||||
from mcp.client.auth.utils import (
|
||||
build_oauth_authorization_server_metadata_discovery_urls,
|
||||
build_protected_resource_metadata_discovery_urls,
|
||||
create_client_info_from_metadata_url,
|
||||
create_client_registration_request,
|
||||
create_oauth_metadata_request,
|
||||
extract_field_from_www_auth,
|
||||
extract_resource_metadata_from_www_auth,
|
||||
extract_scope_from_www_auth,
|
||||
get_client_metadata_scopes,
|
||||
handle_auth_metadata_response,
|
||||
handle_protected_resource_response,
|
||||
handle_registration_response,
|
||||
handle_token_response_scopes,
|
||||
is_valid_client_metadata_url,
|
||||
should_use_client_metadata_url,
|
||||
)
|
||||
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
|
||||
from mcp.shared.auth import (
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthToken,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
from mcp.shared.auth_utils import (
|
||||
calculate_token_expiry,
|
||||
check_resource_allowed,
|
||||
resource_url_from_server_url,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PKCEParameters(BaseModel):
|
||||
"""PKCE (Proof Key for Code Exchange) parameters."""
|
||||
|
||||
code_verifier: str = Field(..., min_length=43, max_length=128)
|
||||
code_challenge: str = Field(..., min_length=43, max_length=128)
|
||||
|
||||
@classmethod
|
||||
def generate(cls) -> "PKCEParameters":
|
||||
"""Generate new PKCE parameters."""
|
||||
code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128))
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
return cls(code_verifier=code_verifier, code_challenge=code_challenge)
|
||||
|
||||
|
||||
class TokenStorage(Protocol):
|
||||
"""Protocol for token storage implementations."""
|
||||
|
||||
async def get_tokens(self) -> OAuthToken | None:
|
||||
"""Get stored tokens."""
|
||||
...
|
||||
|
||||
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||
"""Store tokens."""
|
||||
...
|
||||
|
||||
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
||||
"""Get stored client information."""
|
||||
...
|
||||
|
||||
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
|
||||
"""Store client information."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthContext:
|
||||
"""OAuth flow context."""
|
||||
|
||||
server_url: str
|
||||
client_metadata: OAuthClientMetadata
|
||||
storage: TokenStorage
|
||||
redirect_handler: Callable[[str], Awaitable[None]] | None
|
||||
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None
|
||||
timeout: float = 300.0
|
||||
client_metadata_url: str | None = None
|
||||
|
||||
# Discovered metadata
|
||||
protected_resource_metadata: ProtectedResourceMetadata | None = None
|
||||
oauth_metadata: OAuthMetadata | None = None
|
||||
auth_server_url: str | None = None
|
||||
protocol_version: str | None = None
|
||||
|
||||
# Client registration
|
||||
client_info: OAuthClientInformationFull | None = None
|
||||
|
||||
# Token management
|
||||
current_tokens: OAuthToken | None = None
|
||||
token_expiry_time: float | None = None
|
||||
|
||||
# State
|
||||
lock: anyio.Lock = field(default_factory=anyio.Lock)
|
||||
|
||||
def get_authorization_base_url(self, server_url: str) -> str:
|
||||
"""Extract base URL by removing path component."""
|
||||
parsed = urlparse(server_url)
|
||||
return f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
def update_token_expiry(self, token: OAuthToken) -> None:
|
||||
"""Update token expiry time using shared util function."""
|
||||
self.token_expiry_time = calculate_token_expiry(token.expires_in)
|
||||
|
||||
def is_token_valid(self) -> bool:
|
||||
"""Check if current token is valid."""
|
||||
return bool(
|
||||
self.current_tokens
|
||||
and self.current_tokens.access_token
|
||||
and (not self.token_expiry_time or time.time() <= self.token_expiry_time)
|
||||
)
|
||||
|
||||
def can_refresh_token(self) -> bool:
|
||||
"""Check if token can be refreshed."""
|
||||
return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info)
|
||||
|
||||
def clear_tokens(self) -> None:
|
||||
"""Clear current tokens."""
|
||||
self.current_tokens = None
|
||||
self.token_expiry_time = None
|
||||
|
||||
def get_resource_url(self) -> str:
|
||||
"""Get resource URL for RFC 8707.
|
||||
|
||||
Uses PRM resource if it's a valid parent, otherwise uses canonical server URL.
|
||||
"""
|
||||
resource = resource_url_from_server_url(self.server_url)
|
||||
|
||||
# If PRM provides a resource that's a valid parent, use it
|
||||
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
|
||||
prm_resource = str(self.protected_resource_metadata.resource)
|
||||
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
|
||||
resource = prm_resource
|
||||
|
||||
return resource
|
||||
|
||||
def should_include_resource_param(self, protocol_version: str | None = None) -> bool:
|
||||
"""Determine if the resource parameter should be included in OAuth requests.
|
||||
|
||||
Returns True if:
|
||||
- Protected resource metadata is available, OR
|
||||
- MCP-Protocol-Version header is 2025-06-18 or later
|
||||
"""
|
||||
# If we have protected resource metadata, include the resource param
|
||||
if self.protected_resource_metadata is not None:
|
||||
return True
|
||||
|
||||
# If no protocol version provided, don't include resource param
|
||||
if not protocol_version:
|
||||
return False
|
||||
|
||||
# Check if protocol version is 2025-06-18 or later
|
||||
# Version format is YYYY-MM-DD, so string comparison works
|
||||
return protocol_version >= "2025-06-18"
|
||||
|
||||
def prepare_token_auth(
|
||||
self, data: dict[str, str], headers: dict[str, str] | None = None
|
||||
) -> tuple[dict[str, str], dict[str, str]]:
|
||||
"""Prepare authentication for token requests.
|
||||
|
||||
Args:
|
||||
data: The form data to send
|
||||
headers: Optional headers dict to update
|
||||
|
||||
Returns:
|
||||
Tuple of (updated_data, updated_headers)
|
||||
"""
|
||||
if headers is None:
|
||||
headers = {} # pragma: no cover
|
||||
|
||||
if not self.client_info:
|
||||
return data, headers # pragma: no cover
|
||||
|
||||
auth_method = self.client_info.token_endpoint_auth_method
|
||||
|
||||
if auth_method == "client_secret_basic" and self.client_info.client_id and self.client_info.client_secret:
|
||||
# URL-encode client ID and secret per RFC 6749 Section 2.3.1
|
||||
encoded_id = quote(self.client_info.client_id, safe="")
|
||||
encoded_secret = quote(self.client_info.client_secret, safe="")
|
||||
credentials = f"{encoded_id}:{encoded_secret}"
|
||||
encoded_credentials = base64.b64encode(credentials.encode()).decode()
|
||||
headers["Authorization"] = f"Basic {encoded_credentials}"
|
||||
# Don't include client_secret in body for basic auth
|
||||
data = {k: v for k, v in data.items() if k != "client_secret"}
|
||||
elif auth_method == "client_secret_post" and self.client_info.client_secret:
|
||||
# Include client_secret in request body
|
||||
data["client_secret"] = self.client_info.client_secret
|
||||
# For auth_method == "none", don't add any client_secret
|
||||
|
||||
return data, headers
|
||||
|
||||
|
||||
class OAuthClientProvider(httpx.Auth):
|
||||
"""
|
||||
OAuth2 authentication for httpx.
|
||||
Handles OAuth flow with automatic client registration and token storage.
|
||||
"""
|
||||
|
||||
requires_response_body = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
client_metadata: OAuthClientMetadata,
|
||||
storage: TokenStorage,
|
||||
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
|
||||
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
|
||||
timeout: float = 300.0,
|
||||
client_metadata_url: str | None = None,
|
||||
):
|
||||
"""Initialize OAuth2 authentication.
|
||||
|
||||
Args:
|
||||
server_url: The MCP server URL.
|
||||
client_metadata: OAuth client metadata for registration.
|
||||
storage: Token storage implementation.
|
||||
redirect_handler: Handler for authorization redirects.
|
||||
callback_handler: Handler for authorization callbacks.
|
||||
timeout: Timeout for the OAuth flow.
|
||||
client_metadata_url: URL-based client ID. When provided and the server
|
||||
advertises client_id_metadata_document_supported=true, this URL will be
|
||||
used as the client_id instead of performing dynamic client registration.
|
||||
Must be a valid HTTPS URL with a non-root pathname.
|
||||
|
||||
Raises:
|
||||
ValueError: If client_metadata_url is provided but not a valid HTTPS URL
|
||||
with a non-root pathname.
|
||||
"""
|
||||
# Validate client_metadata_url if provided
|
||||
if client_metadata_url is not None and not is_valid_client_metadata_url(client_metadata_url):
|
||||
raise ValueError(
|
||||
f"client_metadata_url must be a valid HTTPS URL with a non-root pathname, got: {client_metadata_url}"
|
||||
)
|
||||
|
||||
self.context = OAuthContext(
|
||||
server_url=server_url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=redirect_handler,
|
||||
callback_handler=callback_handler,
|
||||
timeout=timeout,
|
||||
client_metadata_url=client_metadata_url,
|
||||
)
|
||||
self._initialized = False
|
||||
|
||||
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
|
||||
"""
|
||||
Handle protected resource metadata discovery response.
|
||||
|
||||
Per SEP-985, supports fallback when discovery fails at one URL.
|
||||
|
||||
Returns:
|
||||
True if metadata was successfully discovered, False if we should try next URL
|
||||
"""
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
content = await response.aread()
|
||||
metadata = ProtectedResourceMetadata.model_validate_json(content)
|
||||
self.context.protected_resource_metadata = metadata
|
||||
if metadata.authorization_servers: # pragma: no branch
|
||||
self.context.auth_server_url = str(metadata.authorization_servers[0])
|
||||
return True
|
||||
|
||||
except ValidationError: # pragma: no cover
|
||||
# Invalid metadata - try next URL
|
||||
logger.warning(f"Invalid protected resource metadata at {response.request.url}")
|
||||
return False
|
||||
elif response.status_code == 404: # pragma: no cover
|
||||
# Not found - try next URL in fallback chain
|
||||
logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
|
||||
return False
|
||||
else:
|
||||
# Other error - fail immediately
|
||||
raise OAuthFlowError(
|
||||
f"Protected Resource Metadata request failed: {response.status_code}"
|
||||
) # pragma: no cover
|
||||
|
||||
async def _perform_authorization(self) -> httpx.Request:
|
||||
"""Perform the authorization flow."""
|
||||
auth_code, code_verifier = await self._perform_authorization_code_grant()
|
||||
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
|
||||
return token_request
|
||||
|
||||
async def _perform_authorization_code_grant(self) -> tuple[str, str]:
|
||||
"""Perform the authorization redirect and get auth code."""
|
||||
if self.context.client_metadata.redirect_uris is None:
|
||||
raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover
|
||||
if not self.context.redirect_handler:
|
||||
raise OAuthFlowError("No redirect handler provided for authorization code grant") # pragma: no cover
|
||||
if not self.context.callback_handler:
|
||||
raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover
|
||||
|
||||
if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint:
|
||||
auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) # pragma: no cover
|
||||
else:
|
||||
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
|
||||
auth_endpoint = urljoin(auth_base_url, "/authorize")
|
||||
|
||||
if not self.context.client_info:
|
||||
raise OAuthFlowError("No client info available for authorization") # pragma: no cover
|
||||
|
||||
# Generate PKCE parameters
|
||||
pkce_params = PKCEParameters.generate()
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
auth_params = {
|
||||
"response_type": "code",
|
||||
"client_id": self.context.client_info.client_id,
|
||||
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
|
||||
"state": state,
|
||||
"code_challenge": pkce_params.code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
# Only include resource param if conditions are met
|
||||
if self.context.should_include_resource_param(self.context.protocol_version):
|
||||
auth_params["resource"] = self.context.get_resource_url() # RFC 8707 # pragma: no cover
|
||||
|
||||
if self.context.client_metadata.scope: # pragma: no branch
|
||||
auth_params["scope"] = self.context.client_metadata.scope
|
||||
|
||||
authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
|
||||
await self.context.redirect_handler(authorization_url)
|
||||
|
||||
# Wait for callback
|
||||
auth_code, returned_state = await self.context.callback_handler()
|
||||
|
||||
if returned_state is None or not secrets.compare_digest(returned_state, state):
|
||||
raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") # pragma: no cover
|
||||
|
||||
if not auth_code:
|
||||
raise OAuthFlowError("No authorization code received") # pragma: no cover
|
||||
|
||||
# Return auth code and code verifier for token exchange
|
||||
return auth_code, pkce_params.code_verifier
|
||||
|
||||
def _get_token_endpoint(self) -> str:
|
||||
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
|
||||
token_url = str(self.context.oauth_metadata.token_endpoint)
|
||||
else:
|
||||
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
|
||||
token_url = urljoin(auth_base_url, "/token")
|
||||
return token_url
|
||||
|
||||
async def _exchange_token_authorization_code(
|
||||
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {}
|
||||
) -> httpx.Request:
|
||||
"""Build token exchange request for authorization_code flow."""
|
||||
if self.context.client_metadata.redirect_uris is None:
|
||||
raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover
|
||||
if not self.context.client_info:
|
||||
raise OAuthFlowError("Missing client info") # pragma: no cover
|
||||
|
||||
token_url = self._get_token_endpoint()
|
||||
token_data = token_data or {}
|
||||
token_data.update(
|
||||
{
|
||||
"grant_type": "authorization_code",
|
||||
"code": auth_code,
|
||||
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
|
||||
"client_id": self.context.client_info.client_id,
|
||||
"code_verifier": code_verifier,
|
||||
}
|
||||
)
|
||||
|
||||
# Only include resource param if conditions are met
|
||||
if self.context.should_include_resource_param(self.context.protocol_version):
|
||||
token_data["resource"] = self.context.get_resource_url() # RFC 8707
|
||||
|
||||
# Prepare authentication based on preferred method
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
token_data, headers = self.context.prepare_token_auth(token_data, headers)
|
||||
|
||||
return httpx.Request("POST", token_url, data=token_data, headers=headers)
|
||||
|
||||
async def _handle_token_response(self, response: httpx.Response) -> None:
|
||||
"""Handle token exchange response."""
|
||||
if response.status_code != 200:
|
||||
body = await response.aread() # pragma: no cover
|
||||
body_text = body.decode("utf-8") # pragma: no cover
|
||||
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") # pragma: no cover
|
||||
|
||||
# Parse and validate response with scope validation
|
||||
token_response = await handle_token_response_scopes(response)
|
||||
|
||||
# Store tokens in context
|
||||
self.context.current_tokens = token_response
|
||||
self.context.update_token_expiry(token_response)
|
||||
await self.context.storage.set_tokens(token_response)
|
||||
|
||||
async def _refresh_token(self) -> httpx.Request:
|
||||
"""Build token refresh request."""
|
||||
if not self.context.current_tokens or not self.context.current_tokens.refresh_token:
|
||||
raise OAuthTokenError("No refresh token available") # pragma: no cover
|
||||
|
||||
if not self.context.client_info or not self.context.client_info.client_id:
|
||||
raise OAuthTokenError("No client info available") # pragma: no cover
|
||||
|
||||
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
|
||||
token_url = str(self.context.oauth_metadata.token_endpoint) # pragma: no cover
|
||||
else:
|
||||
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
|
||||
token_url = urljoin(auth_base_url, "/token")
|
||||
|
||||
refresh_data: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": self.context.current_tokens.refresh_token,
|
||||
"client_id": self.context.client_info.client_id,
|
||||
}
|
||||
|
||||
# Only include resource param if conditions are met
|
||||
if self.context.should_include_resource_param(self.context.protocol_version):
|
||||
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
|
||||
|
||||
# Prepare authentication based on preferred method
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)
|
||||
|
||||
return httpx.Request("POST", token_url, data=refresh_data, headers=headers)
|
||||
|
||||
async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover
|
||||
"""Handle token refresh response. Returns True if successful."""
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Token refresh failed: {response.status_code}")
|
||||
self.context.clear_tokens()
|
||||
return False
|
||||
|
||||
try:
|
||||
content = await response.aread()
|
||||
token_response = OAuthToken.model_validate_json(content)
|
||||
|
||||
self.context.current_tokens = token_response
|
||||
self.context.update_token_expiry(token_response)
|
||||
await self.context.storage.set_tokens(token_response)
|
||||
|
||||
return True
|
||||
except ValidationError:
|
||||
logger.exception("Invalid refresh response")
|
||||
self.context.clear_tokens()
|
||||
return False
|
||||
|
||||
async def _initialize(self) -> None: # pragma: no cover
|
||||
"""Load stored tokens and client info."""
|
||||
self.context.current_tokens = await self.context.storage.get_tokens()
|
||||
self.context.client_info = await self.context.storage.get_client_info()
|
||||
self._initialized = True
|
||||
|
||||
def _add_auth_header(self, request: httpx.Request) -> None:
|
||||
"""Add authorization header to request if we have valid tokens."""
|
||||
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
|
||||
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
|
||||
|
||||
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
|
||||
content = await response.aread()
|
||||
metadata = OAuthMetadata.model_validate_json(content)
|
||||
self.context.oauth_metadata = metadata
|
||||
|
||||
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
|
||||
"""HTTPX auth flow integration."""
|
||||
async with self.context.lock:
|
||||
if not self._initialized:
|
||||
await self._initialize() # pragma: no cover
|
||||
|
||||
# Capture protocol version from request headers
|
||||
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
|
||||
|
||||
if not self.context.is_token_valid() and self.context.can_refresh_token():
|
||||
# Try to refresh token
|
||||
refresh_request = await self._refresh_token() # pragma: no cover
|
||||
refresh_response = yield refresh_request # pragma: no cover
|
||||
|
||||
if not await self._handle_refresh_response(refresh_response): # pragma: no cover
|
||||
# Refresh failed, need full re-authentication
|
||||
self._initialized = False
|
||||
|
||||
if self.context.is_token_valid():
|
||||
self._add_auth_header(request)
|
||||
|
||||
response = yield request
|
||||
|
||||
if response.status_code == 401:
|
||||
# Perform full OAuth flow
|
||||
try:
|
||||
# OAuth flow must be inline due to generator constraints
|
||||
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)
|
||||
|
||||
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
|
||||
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
|
||||
www_auth_resource_metadata_url, self.context.server_url
|
||||
)
|
||||
|
||||
for url in prm_discovery_urls: # pragma: no branch
|
||||
discovery_request = create_oauth_metadata_request(url)
|
||||
|
||||
discovery_response = yield discovery_request # sending request
|
||||
|
||||
prm = await handle_protected_resource_response(discovery_response)
|
||||
if prm:
|
||||
self.context.protected_resource_metadata = prm
|
||||
|
||||
# todo: try all authorization_servers to find the OASM
|
||||
assert (
|
||||
len(prm.authorization_servers) > 0
|
||||
) # this is always true as authorization_servers has a min length of 1
|
||||
|
||||
self.context.auth_server_url = str(prm.authorization_servers[0])
|
||||
break
|
||||
else:
|
||||
logger.debug(f"Protected resource metadata discovery failed: {url}")
|
||||
|
||||
asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
|
||||
self.context.auth_server_url, self.context.server_url
|
||||
)
|
||||
|
||||
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
|
||||
for url in asm_discovery_urls: # pragma: no cover
|
||||
oauth_metadata_request = create_oauth_metadata_request(url)
|
||||
oauth_metadata_response = yield oauth_metadata_request
|
||||
|
||||
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
|
||||
if not ok:
|
||||
break
|
||||
if ok and asm:
|
||||
self.context.oauth_metadata = asm
|
||||
break
|
||||
else:
|
||||
logger.debug(f"OAuth metadata discovery failed: {url}")
|
||||
|
||||
# Step 3: Apply scope selection strategy
|
||||
self.context.client_metadata.scope = get_client_metadata_scopes(
|
||||
extract_scope_from_www_auth(response),
|
||||
self.context.protected_resource_metadata,
|
||||
self.context.oauth_metadata,
|
||||
)
|
||||
|
||||
# Step 4: Register client or use URL-based client ID (CIMD)
|
||||
if not self.context.client_info:
|
||||
if should_use_client_metadata_url(
|
||||
self.context.oauth_metadata, self.context.client_metadata_url
|
||||
):
|
||||
# Use URL-based client ID (CIMD)
|
||||
logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}")
|
||||
client_information = create_client_info_from_metadata_url(
|
||||
self.context.client_metadata_url, # type: ignore[arg-type]
|
||||
redirect_uris=self.context.client_metadata.redirect_uris,
|
||||
)
|
||||
self.context.client_info = client_information
|
||||
await self.context.storage.set_client_info(client_information)
|
||||
else:
|
||||
# Fallback to Dynamic Client Registration
|
||||
registration_request = create_client_registration_request(
|
||||
self.context.oauth_metadata,
|
||||
self.context.client_metadata,
|
||||
self.context.get_authorization_base_url(self.context.server_url),
|
||||
)
|
||||
registration_response = yield registration_request
|
||||
client_information = await handle_registration_response(registration_response)
|
||||
self.context.client_info = client_information
|
||||
await self.context.storage.set_client_info(client_information)
|
||||
|
||||
# Step 5: Perform authorization and complete token exchange
|
||||
token_response = yield await self._perform_authorization()
|
||||
await self._handle_token_response(token_response)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("OAuth flow error")
|
||||
raise
|
||||
|
||||
# Retry with new tokens
|
||||
self._add_auth_header(request)
|
||||
yield request
|
||||
elif response.status_code == 403:
|
||||
# Step 1: Extract error field from WWW-Authenticate header
|
||||
error = extract_field_from_www_auth(response, "error")
|
||||
|
||||
# Step 2: Check if we need to step-up authorization
|
||||
if error == "insufficient_scope": # pragma: no branch
|
||||
try:
|
||||
# Step 2a: Update the required scopes
|
||||
self.context.client_metadata.scope = get_client_metadata_scopes(
|
||||
extract_scope_from_www_auth(response), self.context.protected_resource_metadata
|
||||
)
|
||||
|
||||
# Step 2b: Perform (re-)authorization and token exchange
|
||||
token_response = yield await self._perform_authorization()
|
||||
await self._handle_token_response(token_response)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("OAuth flow error")
|
||||
raise
|
||||
|
||||
# Retry with new tokens
|
||||
self._add_auth_header(request)
|
||||
yield request
|
||||
336
.venv/lib/python3.11/site-packages/mcp/client/auth/utils.py
Normal file
336
.venv/lib/python3.11/site-packages/mcp/client/auth/utils.py
Normal file
@@ -0,0 +1,336 @@
|
||||
import logging
|
||||
import re
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from httpx import Request, Response
|
||||
from pydantic import AnyUrl, ValidationError
|
||||
|
||||
from mcp.client.auth import OAuthRegistrationError, OAuthTokenError
|
||||
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
|
||||
from mcp.shared.auth import (
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthToken,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
from mcp.types import LATEST_PROTOCOL_VERSION
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_field_from_www_auth(response: Response, field_name: str) -> str | None:
|
||||
"""
|
||||
Extract field from WWW-Authenticate header.
|
||||
|
||||
Returns:
|
||||
Field value if found in WWW-Authenticate header, None otherwise
|
||||
"""
|
||||
www_auth_header = response.headers.get("WWW-Authenticate")
|
||||
if not www_auth_header:
|
||||
return None
|
||||
|
||||
# Pattern matches: field_name="value" or field_name=value (unquoted)
|
||||
pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))'
|
||||
match = re.search(pattern, www_auth_header)
|
||||
|
||||
if match:
|
||||
# Return quoted value if present, otherwise unquoted value
|
||||
return match.group(1) or match.group(2)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_scope_from_www_auth(response: Response) -> str | None:
|
||||
"""
|
||||
Extract scope parameter from WWW-Authenticate header as per RFC6750.
|
||||
|
||||
Returns:
|
||||
Scope string if found in WWW-Authenticate header, None otherwise
|
||||
"""
|
||||
return extract_field_from_www_auth(response, "scope")
|
||||
|
||||
|
||||
def extract_resource_metadata_from_www_auth(response: Response) -> str | None:
|
||||
"""
|
||||
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
|
||||
|
||||
Returns:
|
||||
Resource metadata URL if found in WWW-Authenticate header, None otherwise
|
||||
"""
|
||||
if not response or response.status_code != 401:
|
||||
return None # pragma: no cover
|
||||
|
||||
return extract_field_from_www_auth(response, "resource_metadata")
|
||||
|
||||
|
||||
def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]:
|
||||
"""
|
||||
Build ordered list of URLs to try for protected resource metadata discovery.
|
||||
|
||||
Per SEP-985, the client MUST:
|
||||
1. Try resource_metadata from WWW-Authenticate header (if present)
|
||||
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
|
||||
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
|
||||
|
||||
Args:
|
||||
www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header
|
||||
server_url: server url
|
||||
|
||||
Returns:
|
||||
Ordered list of URLs to try for discovery
|
||||
"""
|
||||
urls: list[str] = []
|
||||
|
||||
# Priority 1: WWW-Authenticate header with resource_metadata parameter
|
||||
if www_auth_url:
|
||||
urls.append(www_auth_url)
|
||||
|
||||
# Priority 2-3: Well-known URIs (RFC 9728)
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
# Priority 2: Path-based well-known URI (if server has a path component)
|
||||
if parsed.path and parsed.path != "/":
|
||||
path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}")
|
||||
urls.append(path_based_url)
|
||||
|
||||
# Priority 3: Root-based well-known URI
|
||||
root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
|
||||
urls.append(root_based_url)
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def get_client_metadata_scopes(
|
||||
www_authenticate_scope: str | None,
|
||||
protected_resource_metadata: ProtectedResourceMetadata | None,
|
||||
authorization_server_metadata: OAuthMetadata | None = None,
|
||||
) -> str | None:
|
||||
"""Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec."""
|
||||
# Per MCP spec, scope selection priority order:
|
||||
# 1. Use scope from WWW-Authenticate header (if provided)
|
||||
# 2. Use all scopes from PRM scopes_supported (if available)
|
||||
# 3. Omit scope parameter if neither is available
|
||||
|
||||
if www_authenticate_scope is not None:
|
||||
# Priority 1: WWW-Authenticate header scope
|
||||
return www_authenticate_scope
|
||||
elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None:
|
||||
# Priority 2: PRM scopes_supported
|
||||
return " ".join(protected_resource_metadata.scopes_supported)
|
||||
elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None:
|
||||
return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover
|
||||
else:
|
||||
# Priority 3: Omit scope parameter
|
||||
return None
|
||||
|
||||
|
||||
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
|
||||
"""
|
||||
Generate ordered list of (url, type) tuples for discovery attempts.
|
||||
|
||||
Args:
|
||||
auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None
|
||||
server_url: URL for the MCP server, used as a fallback if auth_server_url is None
|
||||
"""
|
||||
|
||||
if not auth_server_url:
|
||||
# Legacy path using the 2025-03-26 spec:
|
||||
# link: https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization
|
||||
parsed = urlparse(server_url)
|
||||
return [f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server"]
|
||||
|
||||
urls: list[str] = []
|
||||
parsed = urlparse(auth_server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
# RFC 8414: Path-aware OAuth discovery
|
||||
if parsed.path and parsed.path != "/":
|
||||
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
|
||||
urls.append(urljoin(base_url, oauth_path))
|
||||
|
||||
# RFC 8414 section 5: Path-aware OIDC discovery
|
||||
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
|
||||
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
|
||||
urls.append(urljoin(base_url, oidc_path))
|
||||
|
||||
# https://openid.net/specs/openid-connect-discovery-1_0.html
|
||||
oidc_path = f"{parsed.path.rstrip('/')}/.well-known/openid-configuration"
|
||||
urls.append(urljoin(base_url, oidc_path))
|
||||
return urls
|
||||
|
||||
# OAuth root
|
||||
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
|
||||
|
||||
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
|
||||
# https://openid.net/specs/openid-connect-discovery-1_0.html
|
||||
urls.append(urljoin(base_url, "/.well-known/openid-configuration"))
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
async def handle_protected_resource_response(
|
||||
response: Response,
|
||||
) -> ProtectedResourceMetadata | None:
|
||||
"""
|
||||
Handle protected resource metadata discovery response.
|
||||
|
||||
Per SEP-985, supports fallback when discovery fails at one URL.
|
||||
|
||||
Returns:
|
||||
True if metadata was successfully discovered, False if we should try next URL
|
||||
"""
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
content = await response.aread()
|
||||
metadata = ProtectedResourceMetadata.model_validate_json(content)
|
||||
return metadata
|
||||
|
||||
except ValidationError: # pragma: no cover
|
||||
# Invalid metadata - try next URL
|
||||
return None
|
||||
else:
|
||||
# Not found - try next URL in fallback chain
|
||||
return None
|
||||
|
||||
|
||||
async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuthMetadata | None]:
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
content = await response.aread()
|
||||
asm = OAuthMetadata.model_validate_json(content)
|
||||
return True, asm
|
||||
except ValidationError: # pragma: no cover
|
||||
return True, None
|
||||
elif response.status_code < 400 or response.status_code >= 500:
|
||||
return False, None # Non-4XX error, stop trying
|
||||
return True, None
|
||||
|
||||
|
||||
def create_oauth_metadata_request(url: str) -> Request:
|
||||
return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
|
||||
|
||||
|
||||
def create_client_registration_request(
|
||||
auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str
|
||||
) -> Request:
|
||||
"""Build registration request or skip if already registered."""
|
||||
|
||||
if auth_server_metadata and auth_server_metadata.registration_endpoint:
|
||||
registration_url = str(auth_server_metadata.registration_endpoint)
|
||||
else:
|
||||
registration_url = urljoin(auth_base_url, "/register")
|
||||
|
||||
registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
|
||||
return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"})
|
||||
|
||||
|
||||
async def handle_registration_response(response: Response) -> OAuthClientInformationFull:
|
||||
"""Handle registration response."""
|
||||
if response.status_code not in (200, 201):
|
||||
await response.aread()
|
||||
raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")
|
||||
|
||||
try:
|
||||
content = await response.aread()
|
||||
client_info = OAuthClientInformationFull.model_validate_json(content)
|
||||
return client_info
|
||||
# self.context.client_info = client_info
|
||||
# await self.context.storage.set_client_info(client_info)
|
||||
except ValidationError as e: # pragma: no cover
|
||||
raise OAuthRegistrationError(f"Invalid registration response: {e}")
|
||||
|
||||
|
||||
def is_valid_client_metadata_url(url: str | None) -> bool:
|
||||
"""Validate that a URL is suitable for use as a client_id (CIMD).
|
||||
|
||||
The URL must be HTTPS with a non-root pathname.
|
||||
|
||||
Args:
|
||||
url: The URL to validate
|
||||
|
||||
Returns:
|
||||
True if the URL is a valid HTTPS URL with a non-root pathname
|
||||
"""
|
||||
if not url:
|
||||
return False
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
return parsed.scheme == "https" and parsed.path not in ("", "/")
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def should_use_client_metadata_url(
|
||||
oauth_metadata: OAuthMetadata | None,
|
||||
client_metadata_url: str | None,
|
||||
) -> bool:
|
||||
"""Determine if URL-based client ID (CIMD) should be used instead of DCR.
|
||||
|
||||
URL-based client IDs should be used when:
|
||||
1. The server advertises client_id_metadata_document_supported=true
|
||||
2. The client has a valid client_metadata_url configured
|
||||
|
||||
Args:
|
||||
oauth_metadata: OAuth authorization server metadata
|
||||
client_metadata_url: URL-based client ID (already validated)
|
||||
|
||||
Returns:
|
||||
True if CIMD should be used, False if DCR should be used
|
||||
"""
|
||||
if not client_metadata_url:
|
||||
return False
|
||||
|
||||
if not oauth_metadata:
|
||||
return False
|
||||
|
||||
return oauth_metadata.client_id_metadata_document_supported is True
|
||||
|
||||
|
||||
def create_client_info_from_metadata_url(
|
||||
client_metadata_url: str, redirect_uris: list[AnyUrl] | None = None
|
||||
) -> OAuthClientInformationFull:
|
||||
"""Create client information using a URL-based client ID (CIMD).
|
||||
|
||||
When using URL-based client IDs, the URL itself becomes the client_id
|
||||
and no client_secret is used (token_endpoint_auth_method="none").
|
||||
|
||||
Args:
|
||||
client_metadata_url: The URL to use as the client_id
|
||||
redirect_uris: The redirect URIs from the client metadata (passed through for
|
||||
compatibility with OAuthClientInformationFull which inherits from OAuthClientMetadata)
|
||||
|
||||
Returns:
|
||||
OAuthClientInformationFull with the URL as client_id
|
||||
"""
|
||||
return OAuthClientInformationFull(
|
||||
client_id=client_metadata_url,
|
||||
token_endpoint_auth_method="none",
|
||||
redirect_uris=redirect_uris,
|
||||
)
|
||||
|
||||
|
||||
async def handle_token_response_scopes(
|
||||
response: Response,
|
||||
) -> OAuthToken:
|
||||
"""Parse and validate token response with optional scope validation.
|
||||
|
||||
Parses token response JSON. Callers should check response.status_code before calling.
|
||||
|
||||
Args:
|
||||
response: HTTP response from token endpoint (status already checked by caller)
|
||||
|
||||
Returns:
|
||||
Validated OAuthToken model
|
||||
|
||||
Raises:
|
||||
OAuthTokenError: If response JSON is invalid
|
||||
"""
|
||||
try:
|
||||
content = await response.aread()
|
||||
token_response = OAuthToken.model_validate_json(content)
|
||||
return token_response
|
||||
except ValidationError as e: # pragma: no cover
|
||||
raise OAuthTokenError(f"Invalid token response: {e}")
|
||||
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Experimental client features.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
"""
|
||||
|
||||
from mcp.client.experimental.tasks import ExperimentalClientFeatures
|
||||
|
||||
__all__ = ["ExperimentalClientFeatures"]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
Experimental task handler protocols for server -> client requests.
|
||||
|
||||
This module provides Protocol types and default handlers for when servers
|
||||
send task-related requests to clients (the reverse of normal client -> server flow).
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Use cases:
|
||||
- Server sends task-augmented sampling/elicitation request to client
|
||||
- Client creates a local task, spawns background work, returns CreateTaskResult
|
||||
- Server polls client's task status via tasks/get, tasks/result, etc.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.session import RequestResponder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.client.session import ClientSession
|
||||
|
||||
|
||||
class GetTaskHandlerFnT(Protocol):
|
||||
"""Handler for tasks/get requests from server.
|
||||
|
||||
WARNING: This is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.GetTaskRequestParams,
|
||||
) -> types.GetTaskResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class GetTaskResultHandlerFnT(Protocol):
|
||||
"""Handler for tasks/result requests from server.
|
||||
|
||||
WARNING: This is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.GetTaskPayloadRequestParams,
|
||||
) -> types.GetTaskPayloadResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class ListTasksHandlerFnT(Protocol):
|
||||
"""Handler for tasks/list requests from server.
|
||||
|
||||
WARNING: This is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.PaginatedRequestParams | None,
|
||||
) -> types.ListTasksResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class CancelTaskHandlerFnT(Protocol):
|
||||
"""Handler for tasks/cancel requests from server.
|
||||
|
||||
WARNING: This is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CancelTaskRequestParams,
|
||||
) -> types.CancelTaskResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class TaskAugmentedSamplingFnT(Protocol):
|
||||
"""Handler for task-augmented sampling/createMessage requests from server.
|
||||
|
||||
When server sends a CreateMessageRequest with task field, this callback
|
||||
is invoked. The callback should create a task, spawn background work,
|
||||
and return CreateTaskResult immediately.
|
||||
|
||||
WARNING: This is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
task_metadata: types.TaskMetadata,
|
||||
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class TaskAugmentedElicitationFnT(Protocol):
|
||||
"""Handler for task-augmented elicitation/create requests from server.
|
||||
|
||||
When server sends an ElicitRequest with task field, this callback
|
||||
is invoked. The callback should create a task, spawn background work,
|
||||
and return CreateTaskResult immediately.
|
||||
|
||||
WARNING: This is experimental and may change without notice.
|
||||
"""
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.ElicitRequestParams,
|
||||
task_metadata: types.TaskMetadata,
|
||||
) -> types.CreateTaskResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
async def default_get_task_handler(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.GetTaskRequestParams,
|
||||
) -> types.GetTaskResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.METHOD_NOT_FOUND,
|
||||
message="tasks/get not supported",
|
||||
)
|
||||
|
||||
|
||||
async def default_get_task_result_handler(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.GetTaskPayloadRequestParams,
|
||||
) -> types.GetTaskPayloadResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.METHOD_NOT_FOUND,
|
||||
message="tasks/result not supported",
|
||||
)
|
||||
|
||||
|
||||
async def default_list_tasks_handler(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.PaginatedRequestParams | None,
|
||||
) -> types.ListTasksResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.METHOD_NOT_FOUND,
|
||||
message="tasks/list not supported",
|
||||
)
|
||||
|
||||
|
||||
async def default_cancel_task_handler(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CancelTaskRequestParams,
|
||||
) -> types.CancelTaskResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.METHOD_NOT_FOUND,
|
||||
message="tasks/cancel not supported",
|
||||
)
|
||||
|
||||
|
||||
async def default_task_augmented_sampling(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
task_metadata: types.TaskMetadata,
|
||||
) -> types.CreateTaskResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Task-augmented sampling not supported",
|
||||
)
|
||||
|
||||
|
||||
async def default_task_augmented_elicitation(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.ElicitRequestParams,
|
||||
task_metadata: types.TaskMetadata,
|
||||
) -> types.CreateTaskResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Task-augmented elicitation not supported",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExperimentalTaskHandlers:
|
||||
"""Container for experimental task handlers.
|
||||
|
||||
Groups all task-related handlers that handle server -> client requests.
|
||||
This includes both pure task requests (get, list, cancel, result) and
|
||||
task-augmented request handlers (sampling, elicitation with task field).
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Example:
|
||||
handlers = ExperimentalTaskHandlers(
|
||||
get_task=my_get_task_handler,
|
||||
list_tasks=my_list_tasks_handler,
|
||||
)
|
||||
session = ClientSession(..., experimental_task_handlers=handlers)
|
||||
"""
|
||||
|
||||
# Pure task request handlers
|
||||
get_task: GetTaskHandlerFnT = field(default=default_get_task_handler)
|
||||
get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler)
|
||||
list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler)
|
||||
cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler)
|
||||
|
||||
# Task-augmented request handlers
|
||||
augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling)
|
||||
augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation)
|
||||
|
||||
def build_capability(self) -> types.ClientTasksCapability | None:
|
||||
"""Build ClientTasksCapability from the configured handlers.
|
||||
|
||||
Returns a capability object that reflects which handlers are configured
|
||||
(i.e., not using the default "not supported" handlers).
|
||||
|
||||
Returns:
|
||||
ClientTasksCapability if any handlers are provided, None otherwise
|
||||
"""
|
||||
has_list = self.list_tasks is not default_list_tasks_handler
|
||||
has_cancel = self.cancel_task is not default_cancel_task_handler
|
||||
has_sampling = self.augmented_sampling is not default_task_augmented_sampling
|
||||
has_elicitation = self.augmented_elicitation is not default_task_augmented_elicitation
|
||||
|
||||
# If no handlers are provided, return None
|
||||
if not any([has_list, has_cancel, has_sampling, has_elicitation]):
|
||||
return None
|
||||
|
||||
# Build requests capability if any request handlers are provided
|
||||
requests_capability: types.ClientTasksRequestsCapability | None = None
|
||||
if has_sampling or has_elicitation:
|
||||
requests_capability = types.ClientTasksRequestsCapability(
|
||||
sampling=types.TasksSamplingCapability(createMessage=types.TasksCreateMessageCapability())
|
||||
if has_sampling
|
||||
else None,
|
||||
elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability())
|
||||
if has_elicitation
|
||||
else None,
|
||||
)
|
||||
|
||||
return types.ClientTasksCapability(
|
||||
list=types.TasksListCapability() if has_list else None,
|
||||
cancel=types.TasksCancelCapability() if has_cancel else None,
|
||||
requests=requests_capability,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handles_request(request: types.ServerRequest) -> bool:
|
||||
"""Check if this handler handles the given request type."""
|
||||
return isinstance(
|
||||
request.root,
|
||||
types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest,
|
||||
)
|
||||
|
||||
async def handle_request(
|
||||
self,
|
||||
ctx: RequestContext["ClientSession", Any],
|
||||
responder: RequestResponder[types.ServerRequest, types.ClientResult],
|
||||
) -> None:
|
||||
"""Handle a task-related request from the server.
|
||||
|
||||
Call handles_request() first to check if this handler can handle the request.
|
||||
"""
|
||||
client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
|
||||
types.ClientResult | types.ErrorData
|
||||
)
|
||||
|
||||
match responder.request.root:
|
||||
case types.GetTaskRequest(params=params):
|
||||
response = await self.get_task(ctx, params)
|
||||
client_response = client_response_type.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.GetTaskPayloadRequest(params=params):
|
||||
response = await self.get_task_result(ctx, params)
|
||||
client_response = client_response_type.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.ListTasksRequest(params=params):
|
||||
response = await self.list_tasks(ctx, params)
|
||||
client_response = client_response_type.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.CancelTaskRequest(params=params):
|
||||
response = await self.cancel_task(ctx, params)
|
||||
client_response = client_response_type.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case _: # pragma: no cover
|
||||
raise ValueError(f"Unhandled request type: {type(responder.request.root)}")
|
||||
|
||||
|
||||
# Backwards compatibility aliases
|
||||
default_task_augmented_sampling_callback = default_task_augmented_sampling
|
||||
default_task_augmented_elicitation_callback = default_task_augmented_elicitation
|
||||
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Experimental client-side task support.
|
||||
|
||||
This module provides client methods for interacting with MCP tasks.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Example:
|
||||
# Call a tool as a task
|
||||
result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"})
|
||||
task_id = result.task.taskId
|
||||
|
||||
# Get task status
|
||||
status = await session.experimental.get_task(task_id)
|
||||
|
||||
# Get task result when complete
|
||||
if status.status == "completed":
|
||||
result = await session.experimental.get_task_result(task_id, CallToolResult)
|
||||
|
||||
# List all tasks
|
||||
tasks = await session.experimental.list_tasks()
|
||||
|
||||
# Cancel a task
|
||||
await session.experimental.cancel_task(task_id)
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.experimental.tasks.polling import poll_until_terminal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.client.session import ClientSession
|
||||
|
||||
ResultT = TypeVar("ResultT", bound=types.Result)
|
||||
|
||||
|
||||
class ExperimentalClientFeatures:
|
||||
"""
|
||||
Experimental client features for tasks and other experimental APIs.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Access via session.experimental:
|
||||
status = await session.experimental.get_task(task_id)
|
||||
"""
|
||||
|
||||
def __init__(self, session: "ClientSession") -> None:
|
||||
self._session = session
|
||||
|
||||
async def call_tool_as_task(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
*,
|
||||
ttl: int = 60000,
|
||||
meta: dict[str, Any] | None = None,
|
||||
) -> types.CreateTaskResult:
|
||||
"""Call a tool as a task, returning a CreateTaskResult for polling.
|
||||
|
||||
This is a convenience method for calling tools that support task execution.
|
||||
The server will return a task reference instead of the immediate result,
|
||||
which can then be polled via `get_task()` and retrieved via `get_task_result()`.
|
||||
|
||||
Args:
|
||||
name: The tool name
|
||||
arguments: Tool arguments
|
||||
ttl: Task time-to-live in milliseconds (default: 60000 = 1 minute)
|
||||
meta: Optional metadata to include in the request
|
||||
|
||||
Returns:
|
||||
CreateTaskResult containing the task reference
|
||||
|
||||
Example:
|
||||
# Create task
|
||||
result = await session.experimental.call_tool_as_task(
|
||||
"long_running_tool", {"input": "data"}
|
||||
)
|
||||
task_id = result.task.taskId
|
||||
|
||||
# Poll for completion
|
||||
while True:
|
||||
status = await session.experimental.get_task(task_id)
|
||||
if status.status == "completed":
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Get result
|
||||
final = await session.experimental.get_task_result(task_id, CallToolResult)
|
||||
"""
|
||||
_meta: types.RequestParams.Meta | None = None
|
||||
if meta is not None:
|
||||
_meta = types.RequestParams.Meta(**meta)
|
||||
|
||||
return await self._session.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
params=types.CallToolRequestParams(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
task=types.TaskMetadata(ttl=ttl),
|
||||
_meta=_meta,
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CreateTaskResult,
|
||||
)
|
||||
|
||||
async def get_task(self, task_id: str) -> types.GetTaskResult:
|
||||
"""
|
||||
Get the current status of a task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
GetTaskResult containing the task status and metadata
|
||||
"""
|
||||
return await self._session.send_request(
|
||||
types.ClientRequest(
|
||||
types.GetTaskRequest(
|
||||
params=types.GetTaskRequestParams(taskId=task_id),
|
||||
)
|
||||
),
|
||||
types.GetTaskResult,
|
||||
)
|
||||
|
||||
async def get_task_result(
|
||||
self,
|
||||
task_id: str,
|
||||
result_type: type[ResultT],
|
||||
) -> ResultT:
|
||||
"""
|
||||
Get the result of a completed task.
|
||||
|
||||
The result type depends on the original request type:
|
||||
- tools/call tasks return CallToolResult
|
||||
- Other request types return their corresponding result type
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
result_type: The expected result type (e.g., CallToolResult)
|
||||
|
||||
Returns:
|
||||
The task result, validated against result_type
|
||||
"""
|
||||
return await self._session.send_request(
|
||||
types.ClientRequest(
|
||||
types.GetTaskPayloadRequest(
|
||||
params=types.GetTaskPayloadRequestParams(taskId=task_id),
|
||||
)
|
||||
),
|
||||
result_type,
|
||||
)
|
||||
|
||||
async def list_tasks(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
) -> types.ListTasksResult:
|
||||
"""
|
||||
List all tasks.
|
||||
|
||||
Args:
|
||||
cursor: Optional pagination cursor
|
||||
|
||||
Returns:
|
||||
ListTasksResult containing tasks and optional next cursor
|
||||
"""
|
||||
params = types.PaginatedRequestParams(cursor=cursor) if cursor else None
|
||||
return await self._session.send_request(
|
||||
types.ClientRequest(
|
||||
types.ListTasksRequest(params=params),
|
||||
),
|
||||
types.ListTasksResult,
|
||||
)
|
||||
|
||||
async def cancel_task(self, task_id: str) -> types.CancelTaskResult:
|
||||
"""
|
||||
Cancel a running task.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Returns:
|
||||
CancelTaskResult with the updated task state
|
||||
"""
|
||||
return await self._session.send_request(
|
||||
types.ClientRequest(
|
||||
types.CancelTaskRequest(
|
||||
params=types.CancelTaskRequestParams(taskId=task_id),
|
||||
)
|
||||
),
|
||||
types.CancelTaskResult,
|
||||
)
|
||||
|
||||
async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
|
||||
"""
|
||||
Poll a task until it reaches a terminal status.
|
||||
|
||||
Yields GetTaskResult for each poll, allowing the caller to react to
|
||||
status changes (e.g., handle input_required). Exits when task reaches
|
||||
a terminal status (completed, failed, cancelled).
|
||||
|
||||
Respects the pollInterval hint from the server.
|
||||
|
||||
Args:
|
||||
task_id: The task identifier
|
||||
|
||||
Yields:
|
||||
GetTaskResult for each poll
|
||||
|
||||
Example:
|
||||
async for status in session.experimental.poll_task(task_id):
|
||||
print(f"Status: {status.status}")
|
||||
if status.status == "input_required":
|
||||
# Handle elicitation request via tasks/result
|
||||
pass
|
||||
|
||||
# Task is now terminal, get the result
|
||||
result = await session.experimental.get_task_result(task_id, CallToolResult)
|
||||
"""
|
||||
async for status in poll_until_terminal(self.get_task, task_id):
|
||||
yield status
|
||||
615
.venv/lib/python3.11/site-packages/mcp/client/session.py
Normal file
615
.venv/lib/python3.11/site-packages/mcp/client/session.py
Normal file
@@ -0,0 +1,615 @@
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol, overload
|
||||
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.experimental import ExperimentalClientFeatures
|
||||
from mcp.client.experimental.task_handlers import ExperimentalTaskHandlers
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
|
||||
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
|
||||
|
||||
logger = logging.getLogger("client")
|
||||
|
||||
|
||||
class SamplingFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class ElicitationFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.ElicitRequestParams,
|
||||
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class ListRootsFnT(Protocol):
|
||||
async def __call__(
|
||||
self, context: RequestContext["ClientSession", Any]
|
||||
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class LoggingFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None: ... # pragma: no branch
|
||||
|
||||
|
||||
class MessageHandlerFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None: ... # pragma: no branch
|
||||
|
||||
|
||||
async def _default_message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
async def _default_sampling_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Sampling not supported",
|
||||
)
|
||||
|
||||
|
||||
async def _default_elicitation_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.ElicitRequestParams,
|
||||
) -> types.ElicitResult | types.ErrorData:
|
||||
return types.ErrorData( # pragma: no cover
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Elicitation not supported",
|
||||
)
|
||||
|
||||
|
||||
async def _default_list_roots_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
) -> types.ListRootsResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="List roots not supported",
|
||||
)
|
||||
|
||||
|
||||
async def _default_logging_callback(
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
types.ClientRequest,
|
||||
types.ClientNotification,
|
||||
types.ClientResult,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
elicitation_callback: ElicitationFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
logging_callback: LoggingFnT | None = None,
|
||||
message_handler: MessageHandlerFnT | None = None,
|
||||
client_info: types.Implementation | None = None,
|
||||
*,
|
||||
sampling_capabilities: types.SamplingCapability | None = None,
|
||||
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream,
|
||||
write_stream,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
self._client_info = client_info or DEFAULT_CLIENT_INFO
|
||||
self._sampling_callback = sampling_callback or _default_sampling_callback
|
||||
self._sampling_capabilities = sampling_capabilities
|
||||
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
|
||||
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
|
||||
self._logging_callback = logging_callback or _default_logging_callback
|
||||
self._message_handler = message_handler or _default_message_handler
|
||||
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
|
||||
self._server_capabilities: types.ServerCapabilities | None = None
|
||||
self._experimental_features: ExperimentalClientFeatures | None = None
|
||||
|
||||
# Experimental: Task handlers (use defaults if not provided)
|
||||
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
|
||||
|
||||
async def initialize(self) -> types.InitializeResult:
|
||||
sampling = (
|
||||
(self._sampling_capabilities or types.SamplingCapability())
|
||||
if self._sampling_callback is not _default_sampling_callback
|
||||
else None
|
||||
)
|
||||
elicitation = (
|
||||
types.ElicitationCapability(
|
||||
form=types.FormElicitationCapability(),
|
||||
url=types.UrlElicitationCapability(),
|
||||
)
|
||||
if self._elicitation_callback is not _default_elicitation_callback
|
||||
else None
|
||||
)
|
||||
roots = (
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
types.RootsCapability(listChanged=True)
|
||||
if self._list_roots_callback is not _default_list_roots_callback
|
||||
else None
|
||||
)
|
||||
|
||||
result = await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.InitializeRequest(
|
||||
params=types.InitializeRequestParams(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ClientCapabilities(
|
||||
sampling=sampling,
|
||||
elicitation=elicitation,
|
||||
experimental=None,
|
||||
roots=roots,
|
||||
tasks=self._task_handlers.build_capability(),
|
||||
),
|
||||
clientInfo=self._client_info,
|
||||
),
|
||||
)
|
||||
),
|
||||
types.InitializeResult,
|
||||
)
|
||||
|
||||
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
|
||||
|
||||
self._server_capabilities = result.capabilities
|
||||
|
||||
await self.send_notification(types.ClientNotification(types.InitializedNotification()))
|
||||
|
||||
return result
|
||||
|
||||
def get_server_capabilities(self) -> types.ServerCapabilities | None:
|
||||
"""Return the server capabilities received during initialization.
|
||||
|
||||
Returns None if the session has not been initialized yet.
|
||||
"""
|
||||
return self._server_capabilities
|
||||
|
||||
@property
|
||||
def experimental(self) -> ExperimentalClientFeatures:
|
||||
"""Experimental APIs for tasks and other features.
|
||||
|
||||
WARNING: These APIs are experimental and may change without notice.
|
||||
|
||||
Example:
|
||||
status = await session.experimental.get_task(task_id)
|
||||
result = await session.experimental.get_task_result(task_id, CallToolResult)
|
||||
"""
|
||||
if self._experimental_features is None:
|
||||
self._experimental_features = ExperimentalClientFeatures(self)
|
||||
return self._experimental_features
|
||||
|
||||
async def send_ping(self) -> types.EmptyResult:
|
||||
"""Send a ping request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(types.PingRequest()),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self,
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
await self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.ProgressNotification(
|
||||
params=types.ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
message=message,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
|
||||
"""Send a logging/setLevel request."""
|
||||
return await self.send_request( # pragma: no cover
|
||||
types.ClientRequest(
|
||||
types.SetLevelRequest(
|
||||
params=types.SetLevelRequestParams(level=level),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
@overload
|
||||
@deprecated("Use list_resources(params=PaginatedRequestParams(...)) instead")
|
||||
async def list_resources(self, cursor: str | None) -> types.ListResourcesResult: ...
|
||||
|
||||
@overload
|
||||
async def list_resources(self, *, params: types.PaginatedRequestParams | None) -> types.ListResourcesResult: ...
|
||||
|
||||
@overload
|
||||
async def list_resources(self) -> types.ListResourcesResult: ...
|
||||
|
||||
async def list_resources(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
*,
|
||||
params: types.PaginatedRequestParams | None = None,
|
||||
) -> types.ListResourcesResult:
|
||||
"""Send a resources/list request.
|
||||
|
||||
Args:
|
||||
cursor: Simple cursor string for pagination (deprecated, use params instead)
|
||||
params: Full pagination parameters including cursor and any future fields
|
||||
"""
|
||||
if params is not None and cursor is not None:
|
||||
raise ValueError("Cannot specify both cursor and params")
|
||||
|
||||
if params is not None:
|
||||
request_params = params
|
||||
elif cursor is not None:
|
||||
request_params = types.PaginatedRequestParams(cursor=cursor)
|
||||
else:
|
||||
request_params = None
|
||||
|
||||
return await self.send_request(
|
||||
types.ClientRequest(types.ListResourcesRequest(params=request_params)),
|
||||
types.ListResourcesResult,
|
||||
)
|
||||
|
||||
@overload
|
||||
@deprecated("Use list_resource_templates(params=PaginatedRequestParams(...)) instead")
|
||||
async def list_resource_templates(self, cursor: str | None) -> types.ListResourceTemplatesResult: ...
|
||||
|
||||
@overload
|
||||
async def list_resource_templates(
|
||||
self, *, params: types.PaginatedRequestParams | None
|
||||
) -> types.ListResourceTemplatesResult: ...
|
||||
|
||||
@overload
|
||||
async def list_resource_templates(self) -> types.ListResourceTemplatesResult: ...
|
||||
|
||||
async def list_resource_templates(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
*,
|
||||
params: types.PaginatedRequestParams | None = None,
|
||||
) -> types.ListResourceTemplatesResult:
|
||||
"""Send a resources/templates/list request.
|
||||
|
||||
Args:
|
||||
cursor: Simple cursor string for pagination (deprecated, use params instead)
|
||||
params: Full pagination parameters including cursor and any future fields
|
||||
"""
|
||||
if params is not None and cursor is not None:
|
||||
raise ValueError("Cannot specify both cursor and params")
|
||||
|
||||
if params is not None:
|
||||
request_params = params
|
||||
elif cursor is not None:
|
||||
request_params = types.PaginatedRequestParams(cursor=cursor)
|
||||
else:
|
||||
request_params = None
|
||||
|
||||
return await self.send_request(
|
||||
types.ClientRequest(types.ListResourceTemplatesRequest(params=request_params)),
|
||||
types.ListResourceTemplatesResult,
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
|
||||
"""Send a resources/read request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ReadResourceRequest(
|
||||
params=types.ReadResourceRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.ReadResourceResult,
|
||||
)
|
||||
|
||||
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/subscribe request."""
|
||||
return await self.send_request( # pragma: no cover
|
||||
types.ClientRequest(
|
||||
types.SubscribeRequest(
|
||||
params=types.SubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/unsubscribe request."""
|
||||
return await self.send_request( # pragma: no cover
|
||||
types.ClientRequest(
|
||||
types.UnsubscribeRequest(
|
||||
params=types.UnsubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
progress_callback: ProgressFnT | None = None,
|
||||
*,
|
||||
meta: dict[str, Any] | None = None,
|
||||
) -> types.CallToolResult:
|
||||
"""Send a tools/call request with optional progress callback support."""
|
||||
|
||||
_meta: types.RequestParams.Meta | None = None
|
||||
if meta is not None:
|
||||
_meta = types.RequestParams.Meta(**meta)
|
||||
|
||||
result = await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
request_read_timeout_seconds=read_timeout_seconds,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
if not result.isError:
|
||||
await self._validate_tool_result(name, result)
|
||||
|
||||
return result
|
||||
|
||||
async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:
|
||||
"""Validate the structured content of a tool result against its output schema."""
|
||||
if name not in self._tool_output_schemas:
|
||||
# refresh output schema cache
|
||||
await self.list_tools()
|
||||
|
||||
output_schema = None
|
||||
if name in self._tool_output_schemas:
|
||||
output_schema = self._tool_output_schemas.get(name)
|
||||
else:
|
||||
logger.warning(f"Tool {name} not listed by server, cannot validate any structured content")
|
||||
|
||||
if output_schema is not None:
|
||||
from jsonschema import SchemaError, ValidationError, validate
|
||||
|
||||
if result.structuredContent is None:
|
||||
raise RuntimeError(
|
||||
f"Tool {name} has an output schema but did not return structured content"
|
||||
) # pragma: no cover
|
||||
try:
|
||||
validate(result.structuredContent, output_schema)
|
||||
except ValidationError as e:
|
||||
raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}") # pragma: no cover
|
||||
except SchemaError as e: # pragma: no cover
|
||||
raise RuntimeError(f"Invalid schema for tool {name}: {e}") # pragma: no cover
|
||||
|
||||
@overload
|
||||
@deprecated("Use list_prompts(params=PaginatedRequestParams(...)) instead")
|
||||
async def list_prompts(self, cursor: str | None) -> types.ListPromptsResult: ...
|
||||
|
||||
@overload
|
||||
async def list_prompts(self, *, params: types.PaginatedRequestParams | None) -> types.ListPromptsResult: ...
|
||||
|
||||
@overload
|
||||
async def list_prompts(self) -> types.ListPromptsResult: ...
|
||||
|
||||
async def list_prompts(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
*,
|
||||
params: types.PaginatedRequestParams | None = None,
|
||||
) -> types.ListPromptsResult:
|
||||
"""Send a prompts/list request.
|
||||
|
||||
Args:
|
||||
cursor: Simple cursor string for pagination (deprecated, use params instead)
|
||||
params: Full pagination parameters including cursor and any future fields
|
||||
"""
|
||||
if params is not None and cursor is not None:
|
||||
raise ValueError("Cannot specify both cursor and params")
|
||||
|
||||
if params is not None:
|
||||
request_params = params
|
||||
elif cursor is not None:
|
||||
request_params = types.PaginatedRequestParams(cursor=cursor)
|
||||
else:
|
||||
request_params = None
|
||||
|
||||
return await self.send_request(
|
||||
types.ClientRequest(types.ListPromptsRequest(params=request_params)),
|
||||
types.ListPromptsResult,
|
||||
)
|
||||
|
||||
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
|
||||
"""Send a prompts/get request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.GetPromptRequest(
|
||||
params=types.GetPromptRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.GetPromptResult,
|
||||
)
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
ref: types.ResourceTemplateReference | types.PromptReference,
|
||||
argument: dict[str, str],
|
||||
context_arguments: dict[str, str] | None = None,
|
||||
) -> types.CompleteResult:
|
||||
"""Send a completion/complete request."""
|
||||
context = None
|
||||
if context_arguments is not None:
|
||||
context = types.CompletionContext(arguments=context_arguments)
|
||||
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CompleteRequest(
|
||||
params=types.CompleteRequestParams(
|
||||
ref=ref,
|
||||
argument=types.CompletionArgument(**argument),
|
||||
context=context,
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CompleteResult,
|
||||
)
|
||||
|
||||
@overload
|
||||
@deprecated("Use list_tools(params=PaginatedRequestParams(...)) instead")
|
||||
async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ...
|
||||
|
||||
@overload
|
||||
async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ...
|
||||
|
||||
@overload
|
||||
async def list_tools(self) -> types.ListToolsResult: ...
|
||||
|
||||
async def list_tools(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
*,
|
||||
params: types.PaginatedRequestParams | None = None,
|
||||
) -> types.ListToolsResult:
|
||||
"""Send a tools/list request.
|
||||
|
||||
Args:
|
||||
cursor: Simple cursor string for pagination (deprecated, use params instead)
|
||||
params: Full pagination parameters including cursor and any future fields
|
||||
"""
|
||||
if params is not None and cursor is not None:
|
||||
raise ValueError("Cannot specify both cursor and params")
|
||||
|
||||
if params is not None:
|
||||
request_params = params
|
||||
elif cursor is not None:
|
||||
request_params = types.PaginatedRequestParams(cursor=cursor)
|
||||
else:
|
||||
request_params = None
|
||||
|
||||
result = await self.send_request(
|
||||
types.ClientRequest(types.ListToolsRequest(params=request_params)),
|
||||
types.ListToolsResult,
|
||||
)
|
||||
|
||||
# Cache tool output schemas for future validation
|
||||
# Note: don't clear the cache, as we may be using a cursor
|
||||
for tool in result.tools:
|
||||
self._tool_output_schemas[tool.name] = tool.outputSchema
|
||||
|
||||
return result
|
||||
|
||||
async def send_roots_list_changed(self) -> None: # pragma: no cover
|
||||
"""Send a roots/list_changed notification."""
|
||||
await self.send_notification(types.ClientNotification(types.RootsListChangedNotification()))
|
||||
|
||||
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
|
||||
ctx = RequestContext[ClientSession, Any](
|
||||
request_id=responder.request_id,
|
||||
meta=responder.request_meta,
|
||||
session=self,
|
||||
lifespan_context=None,
|
||||
)
|
||||
|
||||
# Delegate to experimental task handler if applicable
|
||||
if self._task_handlers.handles_request(responder.request):
|
||||
with responder:
|
||||
await self._task_handlers.handle_request(ctx, responder)
|
||||
return None
|
||||
|
||||
# Core request handling
|
||||
match responder.request.root:
|
||||
case types.CreateMessageRequest(params=params):
|
||||
with responder:
|
||||
# Check if this is a task-augmented request
|
||||
if params.task is not None:
|
||||
response = await self._task_handlers.augmented_sampling(ctx, params, params.task)
|
||||
else:
|
||||
response = await self._sampling_callback(ctx, params)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.ElicitRequest(params=params):
|
||||
with responder:
|
||||
# Check if this is a task-augmented request
|
||||
if params.task is not None:
|
||||
response = await self._task_handlers.augmented_elicitation(ctx, params, params.task)
|
||||
else:
|
||||
response = await self._elicitation_callback(ctx, params)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.ListRootsRequest():
|
||||
with responder:
|
||||
response = await self._list_roots_callback(ctx)
|
||||
client_response = ClientResponse.validate_python(response)
|
||||
await responder.respond(client_response)
|
||||
|
||||
case types.PingRequest(): # pragma: no cover
|
||||
with responder:
|
||||
return await responder.respond(types.ClientResult(root=types.EmptyResult()))
|
||||
|
||||
case _: # pragma: no cover
|
||||
pass # Task requests handled above by _task_handlers
|
||||
|
||||
return None
|
||||
|
||||
async def _handle_incoming(
|
||||
self,
|
||||
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
"""Handle incoming messages by forwarding to the message handler."""
|
||||
await self._message_handler(req)
|
||||
|
||||
async def _received_notification(self, notification: types.ServerNotification) -> None:
|
||||
"""Handle notifications from the server."""
|
||||
# Process specific notification types
|
||||
match notification.root:
|
||||
case types.LoggingMessageNotification(params=params):
|
||||
await self._logging_callback(params)
|
||||
case types.ElicitCompleteNotification(params=params):
|
||||
# Handle elicitation completion notification
|
||||
# Clients MAY use this to retry requests or update UI
|
||||
# The notification contains the elicitationId of the completed elicitation
|
||||
pass
|
||||
case _:
|
||||
pass
|
||||
447
.venv/lib/python3.11/site-packages/mcp/client/session_group.py
Normal file
447
.venv/lib/python3.11/site-packages/mcp/client/session_group.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""
|
||||
SessionGroup concurrently manages multiple MCP session connections.
|
||||
|
||||
Tools, resources, and prompts are aggregated across servers. Servers may
|
||||
be connected to or disconnected from at any point after initialization.
|
||||
|
||||
This abstractions can handle naming collisions using a custom user-provided
|
||||
hook.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, TypeAlias, overload
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self, deprecated
|
||||
|
||||
import mcp
|
||||
from mcp import types
|
||||
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import StdioServerParameters
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
from mcp.shared._httpx_utils import create_mcp_http_client
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.session import ProgressFnT
|
||||
|
||||
|
||||
class SseServerParameters(BaseModel):
|
||||
"""Parameters for intializing a sse_client."""
|
||||
|
||||
# The endpoint URL.
|
||||
url: str
|
||||
|
||||
# Optional headers to include in requests.
|
||||
headers: dict[str, Any] | None = None
|
||||
|
||||
# HTTP timeout for regular operations.
|
||||
timeout: float = 5
|
||||
|
||||
# Timeout for SSE read operations.
|
||||
sse_read_timeout: float = 60 * 5
|
||||
|
||||
|
||||
class StreamableHttpParameters(BaseModel):
|
||||
"""Parameters for intializing a streamable_http_client."""
|
||||
|
||||
# The endpoint URL.
|
||||
url: str
|
||||
|
||||
# Optional headers to include in requests.
|
||||
headers: dict[str, Any] | None = None
|
||||
|
||||
# HTTP timeout for regular operations.
|
||||
timeout: timedelta = timedelta(seconds=30)
|
||||
|
||||
# Timeout for SSE read operations.
|
||||
sse_read_timeout: timedelta = timedelta(seconds=60 * 5)
|
||||
|
||||
# Close the client session when the transport closes.
|
||||
terminate_on_close: bool = True
|
||||
|
||||
|
||||
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
|
||||
|
||||
|
||||
# Use dataclass instead of pydantic BaseModel
|
||||
# because pydantic BaseModel cannot handle Protocol fields.
|
||||
@dataclass
|
||||
class ClientSessionParameters:
|
||||
"""Parameters for establishing a client session to an MCP server."""
|
||||
|
||||
read_timeout_seconds: timedelta | None = None
|
||||
sampling_callback: SamplingFnT | None = None
|
||||
elicitation_callback: ElicitationFnT | None = None
|
||||
list_roots_callback: ListRootsFnT | None = None
|
||||
logging_callback: LoggingFnT | None = None
|
||||
message_handler: MessageHandlerFnT | None = None
|
||||
client_info: types.Implementation | None = None
|
||||
|
||||
|
||||
class ClientSessionGroup:
|
||||
"""Client for managing connections to multiple MCP servers.
|
||||
|
||||
This class is responsible for encapsulating management of server connections.
|
||||
It aggregates tools, resources, and prompts from all connected servers.
|
||||
|
||||
For auxiliary handlers, such as resource subscription, this is delegated to
|
||||
the client and can be accessed via the session.
|
||||
|
||||
Example Usage:
|
||||
name_fn = lambda name, server_info: f"{(server_info.name)}_{name}"
|
||||
async with ClientSessionGroup(component_name_hook=name_fn) as group:
|
||||
for server_param in server_params:
|
||||
await group.connect_to_server(server_param)
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
class _ComponentNames(BaseModel):
|
||||
"""Used for reverse index to find components."""
|
||||
|
||||
prompts: set[str] = set()
|
||||
resources: set[str] = set()
|
||||
tools: set[str] = set()
|
||||
|
||||
# Standard MCP components.
|
||||
_prompts: dict[str, types.Prompt]
|
||||
_resources: dict[str, types.Resource]
|
||||
_tools: dict[str, types.Tool]
|
||||
|
||||
# Client-server connection management.
|
||||
_sessions: dict[mcp.ClientSession, _ComponentNames]
|
||||
_tool_to_session: dict[str, mcp.ClientSession]
|
||||
_exit_stack: contextlib.AsyncExitStack
|
||||
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
|
||||
|
||||
# Optional fn consuming (component_name, serverInfo) for custom names.
|
||||
# This is provide a means to mitigate naming conflicts across servers.
|
||||
# Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}"
|
||||
_ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str]
|
||||
_component_name_hook: _ComponentNameHook | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exit_stack: contextlib.AsyncExitStack | None = None,
|
||||
component_name_hook: _ComponentNameHook | None = None,
|
||||
) -> None:
|
||||
"""Initializes the MCP client."""
|
||||
|
||||
self._tools = {}
|
||||
self._resources = {}
|
||||
self._prompts = {}
|
||||
|
||||
self._sessions = {}
|
||||
self._tool_to_session = {}
|
||||
if exit_stack is None:
|
||||
self._exit_stack = contextlib.AsyncExitStack()
|
||||
self._owns_exit_stack = True
|
||||
else:
|
||||
self._exit_stack = exit_stack
|
||||
self._owns_exit_stack = False
|
||||
self._session_exit_stacks = {}
|
||||
self._component_name_hook = component_name_hook
|
||||
|
||||
async def __aenter__(self) -> Self: # pragma: no cover
|
||||
# Enter the exit stack only if we created it ourselves
|
||||
if self._owns_exit_stack:
|
||||
await self._exit_stack.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
_exc_type: type[BaseException] | None,
|
||||
_exc_val: BaseException | None,
|
||||
_exc_tb: TracebackType | None,
|
||||
) -> bool | None: # pragma: no cover
|
||||
"""Closes session exit stacks and main exit stack upon completion."""
|
||||
|
||||
# Only close the main exit stack if we created it
|
||||
if self._owns_exit_stack:
|
||||
await self._exit_stack.aclose()
|
||||
|
||||
# Concurrently close session stacks.
|
||||
async with anyio.create_task_group() as tg:
|
||||
for exit_stack in self._session_exit_stacks.values():
|
||||
tg.start_soon(exit_stack.aclose)
|
||||
|
||||
@property
|
||||
def sessions(self) -> list[mcp.ClientSession]:
|
||||
"""Returns the list of sessions being managed."""
|
||||
return list(self._sessions.keys()) # pragma: no cover
|
||||
|
||||
@property
|
||||
def prompts(self) -> dict[str, types.Prompt]:
|
||||
"""Returns the prompts as a dictionary of names to prompts."""
|
||||
return self._prompts
|
||||
|
||||
@property
|
||||
def resources(self) -> dict[str, types.Resource]:
|
||||
"""Returns the resources as a dictionary of names to resources."""
|
||||
return self._resources
|
||||
|
||||
@property
|
||||
def tools(self) -> dict[str, types.Tool]:
|
||||
"""Returns the tools as a dictionary of names to tools."""
|
||||
return self._tools
|
||||
|
||||
@overload
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any],
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
progress_callback: ProgressFnT | None = None,
|
||||
*,
|
||||
meta: dict[str, Any] | None = None,
|
||||
) -> types.CallToolResult: ...
|
||||
|
||||
@overload
|
||||
@deprecated("The 'args' parameter is deprecated. Use 'arguments' instead.")
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
args: dict[str, Any],
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
progress_callback: ProgressFnT | None = None,
|
||||
meta: dict[str, Any] | None = None,
|
||||
) -> types.CallToolResult: ...
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
progress_callback: ProgressFnT | None = None,
|
||||
*,
|
||||
meta: dict[str, Any] | None = None,
|
||||
args: dict[str, Any] | None = None,
|
||||
) -> types.CallToolResult:
|
||||
"""Executes a tool given its name and arguments."""
|
||||
session = self._tool_to_session[name]
|
||||
session_tool_name = self.tools[name].name
|
||||
return await session.call_tool(
|
||||
session_tool_name,
|
||||
arguments if args is None else args,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
progress_callback=progress_callback,
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
|
||||
"""Disconnects from a single MCP server."""
|
||||
|
||||
session_known_for_components = session in self._sessions
|
||||
session_known_for_stack = session in self._session_exit_stacks
|
||||
|
||||
if not session_known_for_components and not session_known_for_stack:
|
||||
raise McpError(
|
||||
types.ErrorData(
|
||||
code=types.INVALID_PARAMS,
|
||||
message="Provided session is not managed or already disconnected.",
|
||||
)
|
||||
)
|
||||
|
||||
if session_known_for_components: # pragma: no cover
|
||||
component_names = self._sessions.pop(session) # Pop from _sessions tracking
|
||||
|
||||
# Remove prompts associated with the session.
|
||||
for name in component_names.prompts:
|
||||
if name in self._prompts:
|
||||
del self._prompts[name]
|
||||
# Remove resources associated with the session.
|
||||
for name in component_names.resources:
|
||||
if name in self._resources:
|
||||
del self._resources[name]
|
||||
# Remove tools associated with the session.
|
||||
for name in component_names.tools:
|
||||
if name in self._tools:
|
||||
del self._tools[name]
|
||||
if name in self._tool_to_session:
|
||||
del self._tool_to_session[name]
|
||||
|
||||
# Clean up the session's resources via its dedicated exit stack
|
||||
if session_known_for_stack:
|
||||
session_stack_to_close = self._session_exit_stacks.pop(session) # pragma: no cover
|
||||
await session_stack_to_close.aclose() # pragma: no cover
|
||||
|
||||
async def connect_with_session(
|
||||
self, server_info: types.Implementation, session: mcp.ClientSession
|
||||
) -> mcp.ClientSession:
|
||||
"""Connects to a single MCP server."""
|
||||
await self._aggregate_components(server_info, session)
|
||||
return session
|
||||
|
||||
async def connect_to_server(
|
||||
self,
|
||||
server_params: ServerParameters,
|
||||
session_params: ClientSessionParameters | None = None,
|
||||
) -> mcp.ClientSession:
|
||||
"""Connects to a single MCP server."""
|
||||
server_info, session = await self._establish_session(server_params, session_params or ClientSessionParameters())
|
||||
return await self.connect_with_session(server_info, session)
|
||||
|
||||
async def _establish_session(
|
||||
self,
|
||||
server_params: ServerParameters,
|
||||
session_params: ClientSessionParameters,
|
||||
) -> tuple[types.Implementation, mcp.ClientSession]:
|
||||
"""Establish a client session to an MCP server."""
|
||||
|
||||
session_stack = contextlib.AsyncExitStack()
|
||||
try:
|
||||
# Create read and write streams that facilitate io with the server.
|
||||
if isinstance(server_params, StdioServerParameters):
|
||||
client = mcp.stdio_client(server_params)
|
||||
read, write = await session_stack.enter_async_context(client)
|
||||
elif isinstance(server_params, SseServerParameters):
|
||||
client = sse_client(
|
||||
url=server_params.url,
|
||||
headers=server_params.headers,
|
||||
timeout=server_params.timeout,
|
||||
sse_read_timeout=server_params.sse_read_timeout,
|
||||
)
|
||||
read, write = await session_stack.enter_async_context(client)
|
||||
else:
|
||||
httpx_client = create_mcp_http_client(
|
||||
headers=server_params.headers,
|
||||
timeout=httpx.Timeout(
|
||||
server_params.timeout.total_seconds(),
|
||||
read=server_params.sse_read_timeout.total_seconds(),
|
||||
),
|
||||
)
|
||||
await session_stack.enter_async_context(httpx_client)
|
||||
|
||||
client = streamable_http_client(
|
||||
url=server_params.url,
|
||||
http_client=httpx_client,
|
||||
terminate_on_close=server_params.terminate_on_close,
|
||||
)
|
||||
read, write, _ = await session_stack.enter_async_context(client)
|
||||
|
||||
session = await session_stack.enter_async_context(
|
||||
mcp.ClientSession(
|
||||
read,
|
||||
write,
|
||||
read_timeout_seconds=session_params.read_timeout_seconds,
|
||||
sampling_callback=session_params.sampling_callback,
|
||||
elicitation_callback=session_params.elicitation_callback,
|
||||
list_roots_callback=session_params.list_roots_callback,
|
||||
logging_callback=session_params.logging_callback,
|
||||
message_handler=session_params.message_handler,
|
||||
client_info=session_params.client_info,
|
||||
)
|
||||
)
|
||||
|
||||
result = await session.initialize()
|
||||
|
||||
# Session successfully initialized.
|
||||
# Store its stack and register the stack with the main group stack.
|
||||
self._session_exit_stacks[session] = session_stack
|
||||
# session_stack itself becomes a resource managed by the
|
||||
# main _exit_stack.
|
||||
await self._exit_stack.enter_async_context(session_stack)
|
||||
|
||||
return result.serverInfo, session
|
||||
except Exception: # pragma: no cover
|
||||
# If anything during this setup fails, ensure the session-specific
|
||||
# stack is closed.
|
||||
await session_stack.aclose()
|
||||
raise
|
||||
|
||||
async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None:
|
||||
"""Aggregates prompts, resources, and tools from a given session."""
|
||||
|
||||
# Create a reverse index so we can find all prompts, resources, and
|
||||
# tools belonging to this session. Used for removing components from
|
||||
# the session group via self.disconnect_from_server.
|
||||
component_names = self._ComponentNames()
|
||||
|
||||
# Temporary components dicts. We do not want to modify the aggregate
|
||||
# lists in case of an intermediate failure.
|
||||
prompts_temp: dict[str, types.Prompt] = {}
|
||||
resources_temp: dict[str, types.Resource] = {}
|
||||
tools_temp: dict[str, types.Tool] = {}
|
||||
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
|
||||
|
||||
# Query the server for its prompts and aggregate to list.
|
||||
try:
|
||||
prompts = (await session.list_prompts()).prompts
|
||||
for prompt in prompts:
|
||||
name = self._component_name(prompt.name, server_info)
|
||||
prompts_temp[name] = prompt
|
||||
component_names.prompts.add(name)
|
||||
except McpError as err: # pragma: no cover
|
||||
logging.warning(f"Could not fetch prompts: {err}")
|
||||
|
||||
# Query the server for its resources and aggregate to list.
|
||||
try:
|
||||
resources = (await session.list_resources()).resources
|
||||
for resource in resources:
|
||||
name = self._component_name(resource.name, server_info)
|
||||
resources_temp[name] = resource
|
||||
component_names.resources.add(name)
|
||||
except McpError as err: # pragma: no cover
|
||||
logging.warning(f"Could not fetch resources: {err}")
|
||||
|
||||
# Query the server for its tools and aggregate to list.
|
||||
try:
|
||||
tools = (await session.list_tools()).tools
|
||||
for tool in tools:
|
||||
name = self._component_name(tool.name, server_info)
|
||||
tools_temp[name] = tool
|
||||
tool_to_session_temp[name] = session
|
||||
component_names.tools.add(name)
|
||||
except McpError as err: # pragma: no cover
|
||||
logging.warning(f"Could not fetch tools: {err}")
|
||||
|
||||
# Clean up exit stack for session if we couldn't retrieve anything
|
||||
# from the server.
|
||||
if not any((prompts_temp, resources_temp, tools_temp)):
|
||||
del self._session_exit_stacks[session] # pragma: no cover
|
||||
|
||||
# Check for duplicates.
|
||||
matching_prompts = prompts_temp.keys() & self._prompts.keys()
|
||||
if matching_prompts:
|
||||
raise McpError( # pragma: no cover
|
||||
types.ErrorData(
|
||||
code=types.INVALID_PARAMS,
|
||||
message=f"{matching_prompts} already exist in group prompts.",
|
||||
)
|
||||
)
|
||||
matching_resources = resources_temp.keys() & self._resources.keys()
|
||||
if matching_resources:
|
||||
raise McpError( # pragma: no cover
|
||||
types.ErrorData(
|
||||
code=types.INVALID_PARAMS,
|
||||
message=f"{matching_resources} already exist in group resources.",
|
||||
)
|
||||
)
|
||||
matching_tools = tools_temp.keys() & self._tools.keys()
|
||||
if matching_tools:
|
||||
raise McpError(
|
||||
types.ErrorData(
|
||||
code=types.INVALID_PARAMS,
|
||||
message=f"{matching_tools} already exist in group tools.",
|
||||
)
|
||||
)
|
||||
|
||||
# Aggregate components.
|
||||
self._sessions[session] = component_names
|
||||
self._prompts.update(prompts_temp)
|
||||
self._resources.update(resources_temp)
|
||||
self._tools.update(tools_temp)
|
||||
self._tool_to_session.update(tool_to_session_temp)
|
||||
|
||||
def _component_name(self, name: str, server_info: types.Implementation) -> str:
|
||||
if self._component_name_hook:
|
||||
return self._component_name_hook(name, server_info)
|
||||
return name
|
||||
164
.venv/lib/python3.11/site-packages/mcp/client/sse.py
Normal file
164
.venv/lib/python3.11/site-packages/mcp/client/sse.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urljoin, urlparse
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import TaskStatus
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import aconnect_sse
|
||||
from httpx_sse._exceptions import SSEError
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
|
||||
return urljoin(url, urlparse(url).path)
|
||||
|
||||
|
||||
def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
|
||||
query_params = parse_qs(urlparse(endpoint_url).query)
|
||||
return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def sse_client(
|
||||
url: str,
|
||||
headers: dict[str, Any] | None = None,
|
||||
timeout: float = 5,
|
||||
sse_read_timeout: float = 60 * 5,
|
||||
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
|
||||
auth: httpx.Auth | None = None,
|
||||
on_session_created: Callable[[str], None] | None = None,
|
||||
):
|
||||
"""
|
||||
Client transport for SSE.
|
||||
|
||||
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
|
||||
event before disconnecting. All other HTTP operations are controlled by `timeout`.
|
||||
|
||||
Args:
|
||||
url: The SSE endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
auth: Optional HTTPX authentication handler.
|
||||
on_session_created: Optional callback invoked with the session ID when received.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[SessionMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
try:
|
||||
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
|
||||
async with httpx_client_factory(
|
||||
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
|
||||
) as client:
|
||||
async with aconnect_sse(
|
||||
client,
|
||||
"GET",
|
||||
url,
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("SSE connection established")
|
||||
|
||||
async def sse_reader(
|
||||
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
|
||||
):
|
||||
try:
|
||||
async for sse in event_source.aiter_sse(): # pragma: no branch
|
||||
logger.debug(f"Received SSE event: {sse.event}")
|
||||
match sse.event:
|
||||
case "endpoint":
|
||||
endpoint_url = urljoin(url, sse.data)
|
||||
logger.debug(f"Received endpoint URL: {endpoint_url}")
|
||||
|
||||
url_parsed = urlparse(url)
|
||||
endpoint_parsed = urlparse(endpoint_url)
|
||||
if ( # pragma: no cover
|
||||
url_parsed.netloc != endpoint_parsed.netloc
|
||||
or url_parsed.scheme != endpoint_parsed.scheme
|
||||
):
|
||||
error_msg = ( # pragma: no cover
|
||||
f"Endpoint origin does not match connection origin: {endpoint_url}"
|
||||
)
|
||||
logger.error(error_msg) # pragma: no cover
|
||||
raise ValueError(error_msg) # pragma: no cover
|
||||
|
||||
if on_session_created:
|
||||
session_id = _extract_session_id_from_endpoint(endpoint_url)
|
||||
if session_id:
|
||||
on_session_created(session_id)
|
||||
|
||||
task_status.started(endpoint_url)
|
||||
|
||||
case "message":
|
||||
# Skip empty data (keep-alive pings)
|
||||
if not sse.data:
|
||||
continue
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
|
||||
sse.data
|
||||
)
|
||||
logger.debug(f"Received server message: {message}")
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.exception("Error parsing server message") # pragma: no cover
|
||||
await read_stream_writer.send(exc) # pragma: no cover
|
||||
continue # pragma: no cover
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(session_message)
|
||||
case _: # pragma: no cover
|
||||
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
|
||||
except SSEError as sse_exc: # pragma: no cover
|
||||
logger.exception("Encountered SSE exception") # pragma: no cover
|
||||
raise sse_exc # pragma: no cover
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.exception("Error in sse_reader") # pragma: no cover
|
||||
await read_stream_writer.send(exc) # pragma: no cover
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
|
||||
async def post_writer(endpoint_url: str):
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for session_message in write_stream_reader:
|
||||
logger.debug(f"Sending client message: {session_message}")
|
||||
response = await client.post(
|
||||
endpoint_url,
|
||||
json=session_message.message.model_dump(
|
||||
by_alias=True,
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.debug(f"Client message sent successfully: {response.status_code}")
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Error in post_writer") # pragma: no cover
|
||||
finally:
|
||||
await write_stream.aclose()
|
||||
|
||||
endpoint_url = await tg.start(sse_reader)
|
||||
logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
|
||||
tg.start_soon(post_writer, endpoint_url)
|
||||
|
||||
try:
|
||||
yield read_stream, write_stream
|
||||
finally:
|
||||
tg.cancel_scope.cancel()
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream.aclose()
|
||||
278
.venv/lib/python3.11/site-packages/mcp/client/stdio/__init__.py
Normal file
278
.venv/lib/python3.11/site-packages/mcp/client/stdio/__init__.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Literal, TextIO
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.abc import Process
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.os.posix.utilities import terminate_posix_process_tree
|
||||
from mcp.os.win32.utilities import (
|
||||
FallbackProcess,
|
||||
create_windows_process,
|
||||
get_windows_executable_command,
|
||||
terminate_windows_process_tree,
|
||||
)
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Environment variables to inherit by default
|
||||
DEFAULT_INHERITED_ENV_VARS = (
|
||||
[
|
||||
"APPDATA",
|
||||
"HOMEDRIVE",
|
||||
"HOMEPATH",
|
||||
"LOCALAPPDATA",
|
||||
"PATH",
|
||||
"PATHEXT",
|
||||
"PROCESSOR_ARCHITECTURE",
|
||||
"SYSTEMDRIVE",
|
||||
"SYSTEMROOT",
|
||||
"TEMP",
|
||||
"USERNAME",
|
||||
"USERPROFILE",
|
||||
]
|
||||
if sys.platform == "win32"
|
||||
else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
|
||||
)
|
||||
|
||||
# Timeout for process termination before falling back to force kill
|
||||
PROCESS_TERMINATION_TIMEOUT = 2.0
|
||||
|
||||
|
||||
def get_default_environment() -> dict[str, str]:
|
||||
"""
|
||||
Returns a default environment object including only environment variables deemed
|
||||
safe to inherit.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
|
||||
for key in DEFAULT_INHERITED_ENV_VARS:
|
||||
value = os.environ.get(key)
|
||||
if value is None:
|
||||
continue # pragma: no cover
|
||||
|
||||
if value.startswith("()"): # pragma: no cover
|
||||
# Skip functions, which are a security risk
|
||||
continue # pragma: no cover
|
||||
|
||||
env[key] = value
|
||||
|
||||
return env
|
||||
|
||||
|
||||
class StdioServerParameters(BaseModel):
|
||||
command: str
|
||||
"""The executable to run to start the server."""
|
||||
|
||||
args: list[str] = Field(default_factory=list)
|
||||
"""Command line arguments to pass to the executable."""
|
||||
|
||||
env: dict[str, str] | None = None
|
||||
"""
|
||||
The environment to use when spawning the process.
|
||||
|
||||
If not specified, the result of get_default_environment() will be used.
|
||||
"""
|
||||
|
||||
cwd: str | Path | None = None
|
||||
"""The working directory to use when spawning the process."""
|
||||
|
||||
encoding: str = "utf-8"
|
||||
"""
|
||||
The text encoding used when sending/receiving messages to the server
|
||||
|
||||
defaults to utf-8
|
||||
"""
|
||||
|
||||
encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict"
|
||||
"""
|
||||
The text encoding error handler.
|
||||
|
||||
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
|
||||
explanations of possible values
|
||||
"""
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr):
|
||||
"""
|
||||
Client transport for stdio: this will connect to a server by spawning a
|
||||
process and communicating with it over stdin/stdout.
|
||||
"""
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
||||
|
||||
write_stream: MemoryObjectSendStream[SessionMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
try:
|
||||
command = _get_executable_command(server.command)
|
||||
|
||||
# Open process with stderr piped for capture
|
||||
process = await _create_platform_compatible_process(
|
||||
command=command,
|
||||
args=server.args,
|
||||
env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
|
||||
errlog=errlog,
|
||||
cwd=server.cwd,
|
||||
)
|
||||
except OSError:
|
||||
# Clean up streams if process creation fails
|
||||
await read_stream.aclose()
|
||||
await write_stream.aclose()
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream_reader.aclose()
|
||||
raise
|
||||
|
||||
async def stdout_reader():
|
||||
assert process.stdout, "Opened process is missing stdout"
|
||||
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
buffer = ""
|
||||
async for chunk in TextReceiveStream(
|
||||
process.stdout,
|
||||
encoding=server.encoding,
|
||||
errors=server.encoding_error_handler,
|
||||
):
|
||||
lines = (buffer + chunk).split("\n")
|
||||
buffer = lines.pop()
|
||||
|
||||
for line in lines:
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(line)
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.exception("Failed to parse JSONRPC message from server")
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except anyio.ClosedResourceError: # pragma: no cover
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async def stdin_writer():
|
||||
assert process.stdin, "Opened process is missing stdin"
|
||||
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for session_message in write_stream_reader:
|
||||
json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
await process.stdin.send(
|
||||
(json + "\n").encode(
|
||||
encoding=server.encoding,
|
||||
errors=server.encoding_error_handler,
|
||||
)
|
||||
)
|
||||
except anyio.ClosedResourceError: # pragma: no cover
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async with (
|
||||
anyio.create_task_group() as tg,
|
||||
process,
|
||||
):
|
||||
tg.start_soon(stdout_reader)
|
||||
tg.start_soon(stdin_writer)
|
||||
try:
|
||||
yield read_stream, write_stream
|
||||
finally:
|
||||
# MCP spec: stdio shutdown sequence
|
||||
# 1. Close input stream to server
|
||||
# 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time
|
||||
# 3. Send SIGKILL if still not exited
|
||||
if process.stdin: # pragma: no branch
|
||||
try:
|
||||
await process.stdin.aclose()
|
||||
except Exception: # pragma: no cover
|
||||
# stdin might already be closed, which is fine
|
||||
pass
|
||||
|
||||
try:
|
||||
# Give the process time to exit gracefully after stdin closes
|
||||
with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT):
|
||||
await process.wait()
|
||||
except TimeoutError:
|
||||
# Process didn't exit from stdin closure, use platform-specific termination
|
||||
# which handles SIGTERM -> SIGKILL escalation
|
||||
await _terminate_process_tree(process)
|
||||
except ProcessLookupError: # pragma: no cover
|
||||
# Process already exited, which is fine
|
||||
pass
|
||||
await read_stream.aclose()
|
||||
await write_stream.aclose()
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream_reader.aclose()
|
||||
|
||||
|
||||
def _get_executable_command(command: str) -> str:
|
||||
"""
|
||||
Get the correct executable command normalized for the current platform.
|
||||
|
||||
Args:
|
||||
command: Base command (e.g., 'uvx', 'npx')
|
||||
|
||||
Returns:
|
||||
str: Platform-appropriate command
|
||||
"""
|
||||
if sys.platform == "win32": # pragma: no cover
|
||||
return get_windows_executable_command(command)
|
||||
else:
|
||||
return command # pragma: no cover
|
||||
|
||||
|
||||
async def _create_platform_compatible_process(
|
||||
command: str,
|
||||
args: list[str],
|
||||
env: dict[str, str] | None = None,
|
||||
errlog: TextIO = sys.stderr,
|
||||
cwd: Path | str | None = None,
|
||||
):
|
||||
"""
|
||||
Creates a subprocess in a platform-compatible way.
|
||||
|
||||
Unix: Creates process in a new session/process group for killpg support
|
||||
Windows: Creates process in a Job Object for reliable child termination
|
||||
"""
|
||||
if sys.platform == "win32": # pragma: no cover
|
||||
process = await create_windows_process(command, args, env, errlog, cwd)
|
||||
else:
|
||||
process = await anyio.open_process(
|
||||
[command, *args],
|
||||
env=env,
|
||||
stderr=errlog,
|
||||
cwd=cwd,
|
||||
start_new_session=True,
|
||||
) # pragma: no cover
|
||||
|
||||
return process
|
||||
|
||||
|
||||
async def _terminate_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None:
|
||||
"""
|
||||
Terminate a process and all its children using platform-specific methods.
|
||||
|
||||
Unix: Uses os.killpg() for atomic process group termination
|
||||
Windows: Uses Job Objects via pywin32 for reliable child process cleanup
|
||||
|
||||
Args:
|
||||
process: The process to terminate
|
||||
timeout_seconds: Timeout in seconds before force killing (default: 2.0)
|
||||
"""
|
||||
if sys.platform == "win32": # pragma: no cover
|
||||
await terminate_windows_process_tree(process, timeout_seconds)
|
||||
else: # pragma: no cover
|
||||
# FallbackProcess should only be used for Windows compatibility
|
||||
assert isinstance(process, Process)
|
||||
await terminate_posix_process_tree(process, timeout_seconds)
|
||||
Binary file not shown.
722
.venv/lib/python3.11/site-packages/mcp/client/streamable_http.py
Normal file
722
.venv/lib/python3.11/site-packages/mcp/client/streamable_http.py
Normal file
@@ -0,0 +1,722 @@
|
||||
"""
|
||||
StreamableHTTP Client Transport Module
|
||||
|
||||
This module implements the StreamableHTTP transport for MCP clients,
|
||||
providing support for HTTP POST requests with optional SSE streaming responses
|
||||
and session management.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Any, overload
|
||||
from warnings import warn
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import TaskGroup
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from mcp.shared._httpx_utils import (
|
||||
McpHttpClientFactory,
|
||||
create_mcp_http_client,
|
||||
)
|
||||
from mcp.shared.message import ClientMessageMetadata, SessionMessage
|
||||
from mcp.types import (
|
||||
ErrorData,
|
||||
InitializeResult,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestId,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SessionMessageOrError = SessionMessage | Exception
|
||||
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
|
||||
StreamReader = MemoryObjectReceiveStream[SessionMessage]
|
||||
GetSessionIdCallback = Callable[[], str | None]
|
||||
|
||||
MCP_SESSION_ID = "mcp-session-id"
|
||||
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
|
||||
LAST_EVENT_ID = "last-event-id"
|
||||
|
||||
# Reconnection defaults
|
||||
DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry
|
||||
MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up
|
||||
CONTENT_TYPE = "content-type"
|
||||
ACCEPT = "accept"
|
||||
|
||||
|
||||
JSON = "application/json"
|
||||
SSE = "text/event-stream"
|
||||
|
||||
# Sentinel value for detecting unset optional parameters
|
||||
_UNSET = object()
|
||||
|
||||
|
||||
class StreamableHTTPError(Exception):
|
||||
"""Base exception for StreamableHTTP transport errors."""
|
||||
|
||||
|
||||
class ResumptionError(StreamableHTTPError):
|
||||
"""Raised when resumption request is invalid."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestContext:
|
||||
"""Context for a request operation."""
|
||||
|
||||
client: httpx.AsyncClient
|
||||
session_id: str | None
|
||||
session_message: SessionMessage
|
||||
metadata: ClientMessageMetadata | None
|
||||
read_stream_writer: StreamWriter
|
||||
headers: dict[str, str] | None = None # Deprecated - no longer used
|
||||
sse_read_timeout: float | None = None # Deprecated - no longer used
|
||||
|
||||
|
||||
class StreamableHTTPTransport:
|
||||
"""StreamableHTTP client transport implementation."""
|
||||
|
||||
@overload
|
||||
def __init__(self, url: str) -> None: ...
|
||||
|
||||
@overload
|
||||
@deprecated(
|
||||
"Parameters headers, timeout, sse_read_timeout, and auth are deprecated. "
|
||||
"Configure these on the httpx.AsyncClient instead."
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | timedelta = 30,
|
||||
sse_read_timeout: float | timedelta = 60 * 5,
|
||||
auth: httpx.Auth | None = None,
|
||||
) -> None: ...
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: Any = _UNSET,
|
||||
timeout: Any = _UNSET,
|
||||
sse_read_timeout: Any = _UNSET,
|
||||
auth: Any = _UNSET,
|
||||
) -> None:
|
||||
"""Initialize the StreamableHTTP transport.
|
||||
|
||||
Args:
|
||||
url: The endpoint URL.
|
||||
headers: Optional headers to include in requests.
|
||||
timeout: HTTP timeout for regular operations.
|
||||
sse_read_timeout: Timeout for SSE read operations.
|
||||
auth: Optional HTTPX authentication handler.
|
||||
"""
|
||||
# Check for deprecated parameters and issue runtime warning
|
||||
deprecated_params: list[str] = []
|
||||
if headers is not _UNSET:
|
||||
deprecated_params.append("headers")
|
||||
if timeout is not _UNSET:
|
||||
deprecated_params.append("timeout")
|
||||
if sse_read_timeout is not _UNSET:
|
||||
deprecated_params.append("sse_read_timeout")
|
||||
if auth is not _UNSET:
|
||||
deprecated_params.append("auth")
|
||||
|
||||
if deprecated_params:
|
||||
warn(
|
||||
f"Parameters {', '.join(deprecated_params)} are deprecated and will be ignored. "
|
||||
"Configure these on the httpx.AsyncClient instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
self.url = url
|
||||
self.session_id = None
|
||||
self.protocol_version = None
|
||||
|
||||
def _prepare_headers(self) -> dict[str, str]:
|
||||
"""Build MCP-specific request headers.
|
||||
|
||||
These headers will be merged with the httpx.AsyncClient's default headers,
|
||||
with these MCP-specific headers taking precedence.
|
||||
"""
|
||||
headers: dict[str, str] = {}
|
||||
# Add MCP protocol headers
|
||||
headers[ACCEPT] = f"{JSON}, {SSE}"
|
||||
headers[CONTENT_TYPE] = JSON
|
||||
# Add session headers if available
|
||||
if self.session_id:
|
||||
headers[MCP_SESSION_ID] = self.session_id
|
||||
if self.protocol_version:
|
||||
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
|
||||
return headers
|
||||
|
||||
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialization request."""
|
||||
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
|
||||
|
||||
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
|
||||
"""Check if the message is an initialized notification."""
|
||||
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
|
||||
|
||||
def _maybe_extract_session_id_from_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
) -> None:
|
||||
"""Extract and store session ID from response headers."""
|
||||
new_session_id = response.headers.get(MCP_SESSION_ID)
|
||||
if new_session_id:
|
||||
self.session_id = new_session_id
|
||||
logger.info(f"Received session ID: {self.session_id}")
|
||||
|
||||
def _maybe_extract_protocol_version_from_message(
|
||||
self,
|
||||
message: JSONRPCMessage,
|
||||
) -> None:
|
||||
"""Extract protocol version from initialization response message."""
|
||||
if isinstance(message.root, JSONRPCResponse) and message.root.result: # pragma: no branch
|
||||
try:
|
||||
# Parse the result as InitializeResult for type safety
|
||||
init_result = InitializeResult.model_validate(message.root.result)
|
||||
self.protocol_version = str(init_result.protocolVersion)
|
||||
logger.info(f"Negotiated protocol version: {self.protocol_version}")
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.warning(
|
||||
f"Failed to parse initialization response as InitializeResult: {exc}"
|
||||
) # pragma: no cover
|
||||
logger.warning(f"Raw result: {message.root.result}")
|
||||
|
||||
async def _handle_sse_event(
|
||||
self,
|
||||
sse: ServerSentEvent,
|
||||
read_stream_writer: StreamWriter,
|
||||
original_request_id: RequestId | None = None,
|
||||
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
|
||||
is_initialization: bool = False,
|
||||
) -> bool:
|
||||
"""Handle an SSE event, returning True if the response is complete."""
|
||||
if sse.event == "message":
|
||||
# Handle priming events (empty data with ID) for resumability
|
||||
if not sse.data:
|
||||
# Call resumption callback for priming events that have an ID
|
||||
if sse.id and resumption_callback:
|
||||
await resumption_callback(sse.id)
|
||||
return False
|
||||
try:
|
||||
message = JSONRPCMessage.model_validate_json(sse.data)
|
||||
logger.debug(f"SSE message: {message}")
|
||||
|
||||
# Extract protocol version from initialization response
|
||||
if is_initialization:
|
||||
self._maybe_extract_protocol_version_from_message(message)
|
||||
|
||||
# If this is a response and we have original_request_id, replace it
|
||||
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
|
||||
message.root.id = original_request_id
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(session_message)
|
||||
|
||||
# Call resumption token callback if we have an ID
|
||||
if sse.id and resumption_callback:
|
||||
await resumption_callback(sse.id)
|
||||
|
||||
# If this is a response or error return True indicating completion
|
||||
# Otherwise, return False to continue listening
|
||||
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
|
||||
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.exception("Error parsing SSE message")
|
||||
await read_stream_writer.send(exc)
|
||||
return False
|
||||
else: # pragma: no cover
|
||||
logger.warning(f"Unknown SSE event: {sse.event}")
|
||||
return False
|
||||
|
||||
async def handle_get_stream(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
read_stream_writer: StreamWriter,
|
||||
) -> None:
|
||||
"""Handle GET stream for server-initiated messages with auto-reconnect."""
|
||||
last_event_id: str | None = None
|
||||
retry_interval_ms: int | None = None
|
||||
attempt: int = 0
|
||||
|
||||
while attempt < MAX_RECONNECTION_ATTEMPTS: # pragma: no branch
|
||||
try:
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
headers = self._prepare_headers()
|
||||
if last_event_id:
|
||||
headers[LAST_EVENT_ID] = last_event_id # pragma: no cover
|
||||
|
||||
async with aconnect_sse(
|
||||
client,
|
||||
"GET",
|
||||
self.url,
|
||||
headers=headers,
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("GET SSE connection established")
|
||||
|
||||
async for sse in event_source.aiter_sse():
|
||||
# Track last event ID for reconnection
|
||||
if sse.id:
|
||||
last_event_id = sse.id # pragma: no cover
|
||||
# Track retry interval from server
|
||||
if sse.retry is not None:
|
||||
retry_interval_ms = sse.retry # pragma: no cover
|
||||
|
||||
await self._handle_sse_event(sse, read_stream_writer)
|
||||
|
||||
# Stream ended normally (server closed) - reset attempt counter
|
||||
attempt = 0
|
||||
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.debug(f"GET stream error: {exc}")
|
||||
attempt += 1
|
||||
|
||||
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
|
||||
logger.debug(f"GET stream max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
|
||||
return
|
||||
|
||||
# Wait before reconnecting
|
||||
delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
|
||||
logger.info(f"GET stream disconnected, reconnecting in {delay_ms}ms...")
|
||||
await anyio.sleep(delay_ms / 1000.0)
|
||||
|
||||
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a resumption request using GET with SSE."""
|
||||
headers = self._prepare_headers()
|
||||
if ctx.metadata and ctx.metadata.resumption_token:
|
||||
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
|
||||
else:
|
||||
raise ResumptionError("Resumption request requires a resumption token") # pragma: no cover
|
||||
|
||||
# Extract original request ID to map responses
|
||||
original_request_id = None
|
||||
if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch
|
||||
original_request_id = ctx.session_message.message.root.id
|
||||
|
||||
async with aconnect_sse(
|
||||
ctx.client,
|
||||
"GET",
|
||||
self.url,
|
||||
headers=headers,
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.debug("Resumption GET SSE connection established")
|
||||
|
||||
async for sse in event_source.aiter_sse(): # pragma: no branch
|
||||
is_complete = await self._handle_sse_event(
|
||||
sse,
|
||||
ctx.read_stream_writer,
|
||||
original_request_id,
|
||||
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
||||
)
|
||||
if is_complete:
|
||||
await event_source.response.aclose()
|
||||
break
|
||||
|
||||
async def _handle_post_request(self, ctx: RequestContext) -> None:
|
||||
"""Handle a POST request with response processing."""
|
||||
headers = self._prepare_headers()
|
||||
message = ctx.session_message.message
|
||||
is_initialization = self._is_initialization_request(message)
|
||||
|
||||
async with ctx.client.stream(
|
||||
"POST",
|
||||
self.url,
|
||||
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
|
||||
headers=headers,
|
||||
) as response:
|
||||
if response.status_code == 202:
|
||||
logger.debug("Received 202 Accepted")
|
||||
return
|
||||
|
||||
if response.status_code == 404: # pragma: no branch
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
await self._send_session_terminated_error( # pragma: no cover
|
||||
ctx.read_stream_writer, # pragma: no cover
|
||||
message.root.id, # pragma: no cover
|
||||
) # pragma: no cover
|
||||
return # pragma: no cover
|
||||
|
||||
response.raise_for_status()
|
||||
if is_initialization:
|
||||
self._maybe_extract_session_id_from_response(response)
|
||||
|
||||
# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
|
||||
# The server MUST NOT send a response to notifications.
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
content_type = response.headers.get(CONTENT_TYPE, "").lower()
|
||||
if content_type.startswith(JSON):
|
||||
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
|
||||
elif content_type.startswith(SSE):
|
||||
await self._handle_sse_response(response, ctx, is_initialization)
|
||||
else:
|
||||
await self._handle_unexpected_content_type( # pragma: no cover
|
||||
content_type, # pragma: no cover
|
||||
ctx.read_stream_writer, # pragma: no cover
|
||||
) # pragma: no cover
|
||||
|
||||
async def _handle_json_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
read_stream_writer: StreamWriter,
|
||||
is_initialization: bool = False,
|
||||
) -> None:
|
||||
"""Handle JSON response from the server."""
|
||||
try:
|
||||
content = await response.aread()
|
||||
message = JSONRPCMessage.model_validate_json(content)
|
||||
|
||||
# Extract protocol version from initialization response
|
||||
if is_initialization:
|
||||
self._maybe_extract_protocol_version_from_message(message)
|
||||
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.exception("Error parsing JSON response")
|
||||
await read_stream_writer.send(exc)
|
||||
|
||||
async def _handle_sse_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
ctx: RequestContext,
|
||||
is_initialization: bool = False,
|
||||
) -> None:
|
||||
"""Handle SSE response from the server."""
|
||||
last_event_id: str | None = None
|
||||
retry_interval_ms: int | None = None
|
||||
|
||||
try:
|
||||
event_source = EventSource(response)
|
||||
async for sse in event_source.aiter_sse(): # pragma: no branch
|
||||
# Track last event ID for potential reconnection
|
||||
if sse.id:
|
||||
last_event_id = sse.id
|
||||
|
||||
# Track retry interval from server
|
||||
if sse.retry is not None:
|
||||
retry_interval_ms = sse.retry
|
||||
|
||||
is_complete = await self._handle_sse_event(
|
||||
sse,
|
||||
ctx.read_stream_writer,
|
||||
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
|
||||
is_initialization=is_initialization,
|
||||
)
|
||||
# If the SSE event indicates completion, like returning respose/error
|
||||
# break the loop
|
||||
if is_complete:
|
||||
await response.aclose()
|
||||
return # Normal completion, no reconnect needed
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.debug(f"SSE stream ended: {e}")
|
||||
|
||||
# Stream ended without response - reconnect if we received an event with ID
|
||||
if last_event_id is not None: # pragma: no branch
|
||||
logger.info("SSE stream disconnected, reconnecting...")
|
||||
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms)
|
||||
|
||||
async def _handle_reconnection(
|
||||
self,
|
||||
ctx: RequestContext,
|
||||
last_event_id: str,
|
||||
retry_interval_ms: int | None = None,
|
||||
attempt: int = 0,
|
||||
) -> None:
|
||||
"""Reconnect with Last-Event-ID to resume stream after server disconnect."""
|
||||
# Bail if max retries exceeded
|
||||
if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover
|
||||
logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded")
|
||||
return
|
||||
|
||||
# Always wait - use server value or default
|
||||
delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
|
||||
await anyio.sleep(delay_ms / 1000.0)
|
||||
|
||||
headers = self._prepare_headers()
|
||||
headers[LAST_EVENT_ID] = last_event_id
|
||||
|
||||
# Extract original request ID to map responses
|
||||
original_request_id = None
|
||||
if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch
|
||||
original_request_id = ctx.session_message.message.root.id
|
||||
|
||||
try:
|
||||
async with aconnect_sse(
|
||||
ctx.client,
|
||||
"GET",
|
||||
self.url,
|
||||
headers=headers,
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
logger.info("Reconnected to SSE stream")
|
||||
|
||||
# Track for potential further reconnection
|
||||
reconnect_last_event_id: str = last_event_id
|
||||
reconnect_retry_ms = retry_interval_ms
|
||||
|
||||
async for sse in event_source.aiter_sse():
|
||||
if sse.id: # pragma: no branch
|
||||
reconnect_last_event_id = sse.id
|
||||
if sse.retry is not None:
|
||||
reconnect_retry_ms = sse.retry
|
||||
|
||||
is_complete = await self._handle_sse_event(
|
||||
sse,
|
||||
ctx.read_stream_writer,
|
||||
original_request_id,
|
||||
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
|
||||
)
|
||||
if is_complete:
|
||||
await event_source.response.aclose()
|
||||
return
|
||||
|
||||
# Stream ended again without response - reconnect again (reset attempt counter)
|
||||
logger.info("SSE stream disconnected, reconnecting...")
|
||||
await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0)
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.debug(f"Reconnection failed: {e}")
|
||||
# Try to reconnect again if we still have an event ID
|
||||
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)
|
||||
|
||||
async def _handle_unexpected_content_type(
|
||||
self,
|
||||
content_type: str,
|
||||
read_stream_writer: StreamWriter,
|
||||
) -> None: # pragma: no cover
|
||||
"""Handle unexpected content type in response."""
|
||||
error_msg = f"Unexpected content type: {content_type}" # pragma: no cover
|
||||
logger.error(error_msg) # pragma: no cover
|
||||
await read_stream_writer.send(ValueError(error_msg)) # pragma: no cover
|
||||
|
||||
async def _send_session_terminated_error(
|
||||
self,
|
||||
read_stream_writer: StreamWriter,
|
||||
request_id: RequestId,
|
||||
) -> None:
|
||||
"""Send a session terminated error response."""
|
||||
jsonrpc_error = JSONRPCError(
|
||||
jsonrpc="2.0",
|
||||
id=request_id,
|
||||
error=ErrorData(code=32600, message="Session terminated"),
|
||||
)
|
||||
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
|
||||
await read_stream_writer.send(session_message)
|
||||
|
||||
async def post_writer(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
write_stream_reader: StreamReader,
|
||||
read_stream_writer: StreamWriter,
|
||||
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||
start_get_stream: Callable[[], None],
|
||||
tg: TaskGroup,
|
||||
) -> None:
|
||||
"""Handle writing requests to the server."""
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for session_message in write_stream_reader:
|
||||
message = session_message.message
|
||||
metadata = (
|
||||
session_message.metadata
|
||||
if isinstance(session_message.metadata, ClientMessageMetadata)
|
||||
else None
|
||||
)
|
||||
|
||||
# Check if this is a resumption request
|
||||
is_resumption = bool(metadata and metadata.resumption_token)
|
||||
|
||||
logger.debug(f"Sending client message: {message}")
|
||||
|
||||
# Handle initialized notification
|
||||
if self._is_initialized_notification(message):
|
||||
start_get_stream()
|
||||
|
||||
ctx = RequestContext(
|
||||
client=client,
|
||||
session_id=self.session_id,
|
||||
session_message=session_message,
|
||||
metadata=metadata,
|
||||
read_stream_writer=read_stream_writer,
|
||||
)
|
||||
|
||||
async def handle_request_async():
|
||||
if is_resumption:
|
||||
await self._handle_resumption_request(ctx)
|
||||
else:
|
||||
await self._handle_post_request(ctx)
|
||||
|
||||
# If this is a request, start a new task to handle it
|
||||
if isinstance(message.root, JSONRPCRequest):
|
||||
tg.start_soon(handle_request_async)
|
||||
else:
|
||||
await handle_request_async()
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in post_writer") # pragma: no cover
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream.aclose()
|
||||
|
||||
async def terminate_session(self, client: httpx.AsyncClient) -> None: # pragma: no cover
|
||||
"""Terminate the session by sending a DELETE request."""
|
||||
if not self.session_id:
|
||||
return
|
||||
|
||||
try:
|
||||
headers = self._prepare_headers()
|
||||
response = await client.delete(self.url, headers=headers)
|
||||
|
||||
if response.status_code == 405:
|
||||
logger.debug("Server does not allow session termination")
|
||||
elif response.status_code not in (200, 204):
|
||||
logger.warning(f"Session termination failed: {response.status_code}")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Session termination failed: {exc}")
|
||||
|
||||
def get_session_id(self) -> str | None:
|
||||
"""Get the current session ID."""
|
||||
return self.session_id
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def streamable_http_client(
|
||||
url: str,
|
||||
*,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
terminate_on_close: bool = True,
|
||||
) -> AsyncGenerator[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
MemoryObjectSendStream[SessionMessage],
|
||||
GetSessionIdCallback,
|
||||
],
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Client transport for StreamableHTTP.
|
||||
|
||||
Args:
|
||||
url: The MCP server endpoint URL.
|
||||
http_client: Optional pre-configured httpx.AsyncClient. If None, a default
|
||||
client with recommended MCP timeouts will be created. To configure headers,
|
||||
authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here.
|
||||
terminate_on_close: If True, send a DELETE request to terminate the session
|
||||
when the context exits.
|
||||
|
||||
Yields:
|
||||
Tuple containing:
|
||||
- read_stream: Stream for reading messages from the server
|
||||
- write_stream: Stream for sending messages to the server
|
||||
- get_session_id_callback: Function to retrieve the current session ID
|
||||
|
||||
Example:
|
||||
See examples/snippets/clients/ for usage patterns.
|
||||
"""
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
|
||||
|
||||
# Determine if we need to create and manage the client
|
||||
client_provided = http_client is not None
|
||||
client = http_client
|
||||
|
||||
if client is None:
|
||||
# Create default client with recommended MCP timeouts
|
||||
client = create_mcp_http_client()
|
||||
|
||||
transport = StreamableHTTPTransport(url)
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
try:
|
||||
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
|
||||
|
||||
async with contextlib.AsyncExitStack() as stack:
|
||||
# Only manage client lifecycle if we created it
|
||||
if not client_provided:
|
||||
await stack.enter_async_context(client)
|
||||
|
||||
def start_get_stream() -> None:
|
||||
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
|
||||
|
||||
tg.start_soon(
|
||||
transport.post_writer,
|
||||
client,
|
||||
write_stream_reader,
|
||||
read_stream_writer,
|
||||
write_stream,
|
||||
start_get_stream,
|
||||
tg,
|
||||
)
|
||||
|
||||
try:
|
||||
yield (
|
||||
read_stream,
|
||||
write_stream,
|
||||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
if transport.session_id and terminate_on_close:
|
||||
await transport.terminate_session(client)
|
||||
tg.cancel_scope.cancel()
|
||||
finally:
|
||||
await read_stream_writer.aclose()
|
||||
await write_stream.aclose()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@deprecated("Use `streamable_http_client` instead.")
|
||||
async def streamablehttp_client(
|
||||
url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: float | timedelta = 30,
|
||||
sse_read_timeout: float | timedelta = 60 * 5,
|
||||
terminate_on_close: bool = True,
|
||||
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
|
||||
auth: httpx.Auth | None = None,
|
||||
) -> AsyncGenerator[
|
||||
tuple[
|
||||
MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
MemoryObjectSendStream[SessionMessage],
|
||||
GetSessionIdCallback,
|
||||
],
|
||||
None,
|
||||
]:
|
||||
# Convert timeout parameters
|
||||
timeout_seconds = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
|
||||
sse_read_timeout_seconds = (
|
||||
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
|
||||
)
|
||||
|
||||
# Create httpx client using the factory with old-style parameters
|
||||
client = httpx_client_factory(
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(timeout_seconds, read=sse_read_timeout_seconds),
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
# Manage client lifecycle since we created it
|
||||
async with client:
|
||||
async with streamable_http_client(
|
||||
url,
|
||||
http_client=client,
|
||||
terminate_on_close=terminate_on_close,
|
||||
) as streams:
|
||||
yield streams
|
||||
86
.venv/lib/python3.11/site-packages/mcp/client/websocket.py
Normal file
86
.venv/lib/python3.11/site-packages/mcp/client/websocket.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import ValidationError
|
||||
from websockets.asyncio.client import connect as ws_connect
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def websocket_client(
|
||||
url: str,
|
||||
) -> AsyncGenerator[
|
||||
tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]],
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
WebSocket client transport for MCP, symmetrical to the server version.
|
||||
|
||||
Connects to 'url' using the 'mcp' subprotocol, then yields:
|
||||
(read_stream, write_stream)
|
||||
|
||||
- read_stream: As you read from this stream, you'll receive either valid
|
||||
JSONRPCMessage objects or Exception objects (when validation fails).
|
||||
- write_stream: Write JSONRPCMessage objects to this stream to send them
|
||||
over the WebSocket to the server.
|
||||
"""
|
||||
|
||||
# Create two in-memory streams:
|
||||
# - One for incoming messages (read_stream, written by ws_reader)
|
||||
# - One for outgoing messages (write_stream, read by ws_writer)
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
|
||||
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
|
||||
write_stream: MemoryObjectSendStream[SessionMessage]
|
||||
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
|
||||
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
# Connect using websockets, requesting the "mcp" subprotocol
|
||||
async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:
|
||||
|
||||
async def ws_reader():
|
||||
"""
|
||||
Reads text messages from the WebSocket, parses them as JSON-RPC messages,
|
||||
and sends them into read_stream_writer.
|
||||
"""
|
||||
async with read_stream_writer:
|
||||
async for raw_text in ws:
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(raw_text)
|
||||
session_message = SessionMessage(message)
|
||||
await read_stream_writer.send(session_message)
|
||||
except ValidationError as exc: # pragma: no cover
|
||||
# If JSON parse or model validation fails, send the exception
|
||||
await read_stream_writer.send(exc)
|
||||
|
||||
async def ws_writer():
|
||||
"""
|
||||
Reads JSON-RPC messages from write_stream_reader and
|
||||
sends them to the server.
|
||||
"""
|
||||
async with write_stream_reader:
|
||||
async for session_message in write_stream_reader:
|
||||
# Convert to a dict, then to JSON
|
||||
msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
await ws.send(json.dumps(msg_dict))
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
# Start reader and writer tasks
|
||||
tg.start_soon(ws_reader)
|
||||
tg.start_soon(ws_writer)
|
||||
|
||||
# Yield the receive/send streams
|
||||
yield (read_stream, write_stream)
|
||||
|
||||
# Once the caller's 'async with' block exits, we shut down
|
||||
tg.cancel_scope.cancel()
|
||||
1
.venv/lib/python3.11/site-packages/mcp/os/__init__.py
Normal file
1
.venv/lib/python3.11/site-packages/mcp/os/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Platform-specific utilities for MCP."""
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
"""POSIX-specific utilities for MCP."""
|
||||
Binary file not shown.
Binary file not shown.
60
.venv/lib/python3.11/site-packages/mcp/os/posix/utilities.py
Normal file
60
.venv/lib/python3.11/site-packages/mcp/os/posix/utilities.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
POSIX-specific functionality for stdio client operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
|
||||
import anyio
|
||||
from anyio.abc import Process
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def terminate_posix_process_tree(process: Process, timeout_seconds: float = 2.0) -> None:
|
||||
"""
|
||||
Terminate a process and all its children on POSIX systems.
|
||||
|
||||
Uses os.killpg() for atomic process group termination.
|
||||
|
||||
Args:
|
||||
process: The process to terminate
|
||||
timeout_seconds: Timeout in seconds before force killing (default: 2.0)
|
||||
"""
|
||||
pid = getattr(process, "pid", None) or getattr(getattr(process, "popen", None), "pid", None)
|
||||
if not pid:
|
||||
# No PID means there's no process to terminate - it either never started,
|
||||
# already exited, or we have an invalid process object
|
||||
return
|
||||
|
||||
try:
|
||||
pgid = os.getpgid(pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
|
||||
with anyio.move_on_after(timeout_seconds):
|
||||
while True:
|
||||
try:
|
||||
# Check if process group still exists (signal 0 = check only)
|
||||
os.killpg(pgid, 0)
|
||||
await anyio.sleep(0.1)
|
||||
except ProcessLookupError:
|
||||
return
|
||||
|
||||
try:
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
|
||||
except (ProcessLookupError, PermissionError, OSError) as e:
|
||||
logger.warning(f"Process group termination failed for PID {pid}: {e}, falling back to simple terminate")
|
||||
try:
|
||||
process.terminate()
|
||||
with anyio.fail_after(timeout_seconds):
|
||||
await process.wait()
|
||||
except Exception:
|
||||
logger.warning(f"Process termination failed for PID {pid}, attempting force kill")
|
||||
try:
|
||||
process.kill()
|
||||
except Exception:
|
||||
logger.exception(f"Failed to kill process {pid}")
|
||||
@@ -0,0 +1 @@
|
||||
"""Windows-specific utilities for MCP."""
|
||||
Binary file not shown.
Binary file not shown.
338
.venv/lib/python3.11/site-packages/mcp/os/win32/utilities.py
Normal file
338
.venv/lib/python3.11/site-packages/mcp/os/win32/utilities.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Windows-specific functionality for stdio client operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import BinaryIO, TextIO, cast
|
||||
|
||||
import anyio
|
||||
from anyio import to_thread
|
||||
from anyio.abc import Process
|
||||
from anyio.streams.file import FileReadStream, FileWriteStream
|
||||
from typing_extensions import deprecated
|
||||
|
||||
logger = logging.getLogger("client.stdio.win32")
|
||||
|
||||
# Windows-specific imports for Job Objects
|
||||
if sys.platform == "win32":
|
||||
import pywintypes
|
||||
import win32api
|
||||
import win32con
|
||||
import win32job
|
||||
else:
|
||||
# Type stubs for non-Windows platforms
|
||||
win32api = None
|
||||
win32con = None
|
||||
win32job = None
|
||||
pywintypes = None
|
||||
|
||||
JobHandle = int
|
||||
|
||||
|
||||
def get_windows_executable_command(command: str) -> str:
|
||||
"""
|
||||
Get the correct executable command normalized for Windows.
|
||||
|
||||
On Windows, commands might exist with specific extensions (.exe, .cmd, etc.)
|
||||
that need to be located for proper execution.
|
||||
|
||||
Args:
|
||||
command: Base command (e.g., 'uvx', 'npx')
|
||||
|
||||
Returns:
|
||||
str: Windows-appropriate command path
|
||||
"""
|
||||
try:
|
||||
# First check if command exists in PATH as-is
|
||||
if command_path := shutil.which(command):
|
||||
return command_path
|
||||
|
||||
# Check for Windows-specific extensions
|
||||
for ext in [".cmd", ".bat", ".exe", ".ps1"]:
|
||||
ext_version = f"{command}{ext}"
|
||||
if ext_path := shutil.which(ext_version):
|
||||
return ext_path
|
||||
|
||||
# For regular commands or if we couldn't find special versions
|
||||
return command
|
||||
except OSError:
|
||||
# Handle file system errors during path resolution
|
||||
# (permissions, broken symlinks, etc.)
|
||||
return command
|
||||
|
||||
|
||||
class FallbackProcess:
|
||||
"""
|
||||
A fallback process wrapper for Windows to handle async I/O
|
||||
when using subprocess.Popen, which provides sync-only FileIO objects.
|
||||
|
||||
This wraps stdin and stdout into async-compatible
|
||||
streams (FileReadStream, FileWriteStream),
|
||||
so that MCP clients expecting async streams can work properly.
|
||||
"""
|
||||
|
||||
def __init__(self, popen_obj: subprocess.Popen[bytes]):
|
||||
self.popen: subprocess.Popen[bytes] = popen_obj
|
||||
self.stdin_raw = popen_obj.stdin # type: ignore[assignment]
|
||||
self.stdout_raw = popen_obj.stdout # type: ignore[assignment]
|
||||
self.stderr = popen_obj.stderr # type: ignore[assignment]
|
||||
|
||||
self.stdin = FileWriteStream(cast(BinaryIO, self.stdin_raw)) if self.stdin_raw else None
|
||||
self.stdout = FileReadStream(cast(BinaryIO, self.stdout_raw)) if self.stdout_raw else None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Support async context manager entry."""
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: BaseException | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: object | None,
|
||||
) -> None:
|
||||
"""Terminate and wait on process exit inside a thread."""
|
||||
self.popen.terminate()
|
||||
await to_thread.run_sync(self.popen.wait)
|
||||
|
||||
# Close the file handles to prevent ResourceWarning
|
||||
if self.stdin:
|
||||
await self.stdin.aclose()
|
||||
if self.stdout:
|
||||
await self.stdout.aclose()
|
||||
if self.stdin_raw:
|
||||
self.stdin_raw.close()
|
||||
if self.stdout_raw:
|
||||
self.stdout_raw.close()
|
||||
if self.stderr:
|
||||
self.stderr.close()
|
||||
|
||||
async def wait(self):
|
||||
"""Async wait for process completion."""
|
||||
return await to_thread.run_sync(self.popen.wait)
|
||||
|
||||
def terminate(self):
|
||||
"""Terminate the subprocess immediately."""
|
||||
return self.popen.terminate()
|
||||
|
||||
def kill(self) -> None:
|
||||
"""Kill the subprocess immediately (alias for terminate)."""
|
||||
self.terminate()
|
||||
|
||||
@property
|
||||
def pid(self) -> int:
|
||||
"""Return the process ID."""
|
||||
return self.popen.pid
|
||||
|
||||
|
||||
# ------------------------
|
||||
# Updated function
|
||||
# ------------------------
|
||||
|
||||
|
||||
async def create_windows_process(
|
||||
command: str,
|
||||
args: list[str],
|
||||
env: dict[str, str] | None = None,
|
||||
errlog: TextIO | None = sys.stderr,
|
||||
cwd: Path | str | None = None,
|
||||
) -> Process | FallbackProcess:
|
||||
"""
|
||||
Creates a subprocess in a Windows-compatible way with Job Object support.
|
||||
|
||||
Attempt to use anyio's open_process for async subprocess creation.
|
||||
In some cases this will throw NotImplementedError on Windows, e.g.
|
||||
when using the SelectorEventLoop which does not support async subprocesses.
|
||||
In that case, we fall back to using subprocess.Popen.
|
||||
|
||||
The process is automatically added to a Job Object to ensure all child
|
||||
processes are terminated when the parent is terminated.
|
||||
|
||||
Args:
|
||||
command (str): The executable to run
|
||||
args (list[str]): List of command line arguments
|
||||
env (dict[str, str] | None): Environment variables
|
||||
errlog (TextIO | None): Where to send stderr output (defaults to sys.stderr)
|
||||
cwd (Path | str | None): Working directory for the subprocess
|
||||
|
||||
Returns:
|
||||
Process | FallbackProcess: Async-compatible subprocess with stdin and stdout streams
|
||||
"""
|
||||
job = _create_job_object()
|
||||
process = None
|
||||
|
||||
try:
|
||||
# First try using anyio with Windows-specific flags to hide console window
|
||||
process = await anyio.open_process(
|
||||
[command, *args],
|
||||
env=env,
|
||||
# Ensure we don't create console windows for each process
|
||||
creationflags=subprocess.CREATE_NO_WINDOW # type: ignore
|
||||
if hasattr(subprocess, "CREATE_NO_WINDOW")
|
||||
else 0,
|
||||
stderr=errlog,
|
||||
cwd=cwd,
|
||||
)
|
||||
except NotImplementedError:
|
||||
# If Windows doesn't support async subprocess creation, use fallback
|
||||
process = await _create_windows_fallback_process(command, args, env, errlog, cwd)
|
||||
except Exception:
|
||||
# Try again without creation flags
|
||||
process = await anyio.open_process(
|
||||
[command, *args],
|
||||
env=env,
|
||||
stderr=errlog,
|
||||
cwd=cwd,
|
||||
)
|
||||
|
||||
_maybe_assign_process_to_job(process, job)
|
||||
return process
|
||||
|
||||
|
||||
async def _create_windows_fallback_process(
|
||||
command: str,
|
||||
args: list[str],
|
||||
env: dict[str, str] | None = None,
|
||||
errlog: TextIO | None = sys.stderr,
|
||||
cwd: Path | str | None = None,
|
||||
) -> FallbackProcess:
|
||||
"""
|
||||
Create a subprocess using subprocess.Popen as a fallback when anyio fails.
|
||||
|
||||
This function wraps the sync subprocess.Popen in an async-compatible interface.
|
||||
"""
|
||||
try:
|
||||
# Try launching with creationflags to avoid opening a new console window
|
||||
popen_obj = subprocess.Popen(
|
||||
[command, *args],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=errlog,
|
||||
env=env,
|
||||
cwd=cwd,
|
||||
bufsize=0, # Unbuffered output
|
||||
creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0),
|
||||
)
|
||||
except Exception:
|
||||
# If creationflags failed, fallback without them
|
||||
popen_obj = subprocess.Popen(
|
||||
[command, *args],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=errlog,
|
||||
env=env,
|
||||
cwd=cwd,
|
||||
bufsize=0,
|
||||
)
|
||||
return FallbackProcess(popen_obj)
|
||||
|
||||
|
||||
def _create_job_object() -> int | None:
|
||||
"""
|
||||
Create a Windows Job Object configured to terminate all processes when closed.
|
||||
"""
|
||||
if sys.platform != "win32" or not win32job:
|
||||
return None
|
||||
|
||||
try:
|
||||
job = win32job.CreateJobObject(None, "")
|
||||
extended_info = win32job.QueryInformationJobObject(job, win32job.JobObjectExtendedLimitInformation)
|
||||
|
||||
extended_info["BasicLimitInformation"]["LimitFlags"] |= win32job.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE
|
||||
win32job.SetInformationJobObject(job, win32job.JobObjectExtendedLimitInformation, extended_info)
|
||||
return job
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create Job Object for process tree management: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _maybe_assign_process_to_job(process: Process | FallbackProcess, job: JobHandle | None) -> None:
|
||||
"""
|
||||
Try to assign a process to a job object. If assignment fails
|
||||
for any reason, the job handle is closed.
|
||||
"""
|
||||
if not job:
|
||||
return
|
||||
|
||||
if sys.platform != "win32" or not win32api or not win32con or not win32job:
|
||||
return
|
||||
|
||||
try:
|
||||
process_handle = win32api.OpenProcess(
|
||||
win32con.PROCESS_SET_QUOTA | win32con.PROCESS_TERMINATE, False, process.pid
|
||||
)
|
||||
if not process_handle:
|
||||
raise Exception("Failed to open process handle")
|
||||
|
||||
try:
|
||||
win32job.AssignProcessToJobObject(job, process_handle)
|
||||
process._job_object = job
|
||||
finally:
|
||||
win32api.CloseHandle(process_handle)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to assign process {process.pid} to Job Object: {e}")
|
||||
if win32api:
|
||||
win32api.CloseHandle(job)
|
||||
|
||||
|
||||
async def terminate_windows_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None:
|
||||
"""
|
||||
Terminate a process and all its children on Windows.
|
||||
|
||||
If the process has an associated job object, it will be terminated.
|
||||
Otherwise, falls back to basic process termination.
|
||||
|
||||
Args:
|
||||
process: The process to terminate
|
||||
timeout_seconds: Timeout in seconds before force killing (default: 2.0)
|
||||
"""
|
||||
if sys.platform != "win32":
|
||||
return
|
||||
|
||||
job = getattr(process, "_job_object", None)
|
||||
if job and win32job:
|
||||
try:
|
||||
win32job.TerminateJobObject(job, 1)
|
||||
except Exception:
|
||||
# Job might already be terminated
|
||||
pass
|
||||
finally:
|
||||
if win32api:
|
||||
try:
|
||||
win32api.CloseHandle(job)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Always try to terminate the process itself as well
|
||||
try:
|
||||
process.terminate()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@deprecated(
|
||||
"terminate_windows_process is deprecated and will be removed in a future version. "
|
||||
"Process termination is now handled internally by the stdio_client context manager."
|
||||
)
|
||||
async def terminate_windows_process(process: Process | FallbackProcess):
|
||||
"""
|
||||
Terminate a Windows process.
|
||||
|
||||
Note: On Windows, terminating a process with process.terminate() doesn't
|
||||
always guarantee immediate process termination.
|
||||
So we give it 2s to exit, or we call process.kill()
|
||||
which sends a SIGKILL equivalent signal.
|
||||
|
||||
Args:
|
||||
process: The process to terminate
|
||||
"""
|
||||
try:
|
||||
process.terminate()
|
||||
with anyio.fail_after(2.0):
|
||||
await process.wait()
|
||||
except TimeoutError:
|
||||
# Force kill if it doesn't terminate
|
||||
process.kill()
|
||||
0
.venv/lib/python3.11/site-packages/mcp/py.typed
Normal file
0
.venv/lib/python3.11/site-packages/mcp/py.typed
Normal file
@@ -0,0 +1,5 @@
|
||||
from .fastmcp import FastMCP
|
||||
from .lowlevel import NotificationOptions, Server
|
||||
from .models import InitializationOptions
|
||||
|
||||
__all__ = ["Server", "FastMCP", "NotificationOptions", "InitializationOptions"]
|
||||
50
.venv/lib/python3.11/site-packages/mcp/server/__main__.py
Normal file
50
.venv/lib/python3.11/site-packages/mcp/server/__main__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import importlib.metadata
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import anyio
|
||||
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.server.session import ServerSession
|
||||
from mcp.server.stdio import stdio_server
|
||||
from mcp.types import ServerCapabilities
|
||||
|
||||
if not sys.warnoptions:
|
||||
import warnings
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("server")
|
||||
|
||||
|
||||
async def receive_loop(session: ServerSession):
|
||||
logger.info("Starting receive loop")
|
||||
async for message in session.incoming_messages:
|
||||
if isinstance(message, Exception):
|
||||
logger.error("Error: %s", message)
|
||||
continue
|
||||
|
||||
logger.info("Received message from client: %s", message)
|
||||
|
||||
|
||||
async def main():
|
||||
version = importlib.metadata.version("mcp")
|
||||
async with stdio_server() as (read_stream, write_stream):
|
||||
async with (
|
||||
ServerSession(
|
||||
read_stream,
|
||||
write_stream,
|
||||
InitializationOptions(
|
||||
server_name="mcp",
|
||||
server_version=version,
|
||||
capabilities=ServerCapabilities(),
|
||||
),
|
||||
) as session,
|
||||
write_stream,
|
||||
):
|
||||
await receive_loop(session)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
anyio.run(main, backend="trio")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
MCP OAuth server authorization components.
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,5 @@
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
def stringify_pydantic_error(validation_error: ValidationError) -> str:
|
||||
return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors())
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Request handlers for MCP authorization endpoints.
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,224 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError
|
||||
from starlette.datastructures import FormData, QueryParams
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import RedirectResponse, Response
|
||||
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.provider import (
|
||||
AuthorizationErrorCode,
|
||||
AuthorizationParams,
|
||||
AuthorizeError,
|
||||
OAuthAuthorizationServerProvider,
|
||||
construct_redirect_uri,
|
||||
)
|
||||
from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthorizationRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
|
||||
client_id: str = Field(..., description="The client ID")
|
||||
redirect_uri: AnyUrl | None = Field(None, description="URL to redirect to after authorization")
|
||||
|
||||
# see OAuthClientMetadata; we only support `code`
|
||||
response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow")
|
||||
code_challenge: str = Field(..., description="PKCE code challenge")
|
||||
code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256")
|
||||
state: str | None = Field(None, description="Optional state parameter")
|
||||
scope: str | None = Field(
|
||||
None,
|
||||
description="Optional scope; if specified, should be a space-separated list of scope strings",
|
||||
)
|
||||
resource: str | None = Field(
|
||||
None,
|
||||
description="RFC 8707 resource indicator - the MCP server this token will be used with",
|
||||
)
|
||||
|
||||
|
||||
class AuthorizationErrorResponse(BaseModel):
|
||||
error: AuthorizationErrorCode
|
||||
error_description: str | None
|
||||
error_uri: AnyUrl | None = None
|
||||
# must be set if provided in the request
|
||||
state: str | None = None
|
||||
|
||||
|
||||
def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None:
|
||||
if params is None: # pragma: no cover
|
||||
return None
|
||||
value = params.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class AnyUrlModel(RootModel[AnyUrl]):
|
||||
root: AnyUrl
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthorizationHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
# implements authorization requests for grant_type=code;
|
||||
# see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
|
||||
|
||||
state = None
|
||||
redirect_uri = None
|
||||
client = None
|
||||
params = None
|
||||
|
||||
async def error_response(
|
||||
error: AuthorizationErrorCode,
|
||||
error_description: str | None,
|
||||
attempt_load_client: bool = True,
|
||||
):
|
||||
# Error responses take two different formats:
|
||||
# 1. The request has a valid client ID & redirect_uri: we issue a redirect
|
||||
# back to the redirect_uri with the error response fields as query
|
||||
# parameters. This allows the client to be notified of the error.
|
||||
# 2. Otherwise, we return an error response directly to the end user;
|
||||
# we choose to do so in JSON, but this is left undefined in the
|
||||
# specification.
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1
|
||||
#
|
||||
# This logic is a bit awkward to handle, because the error might be thrown
|
||||
# very early in request validation, before we've done the usual Pydantic
|
||||
# validation, loaded the client, etc. To handle this, error_response()
|
||||
# contains fallback logic which attempts to load the parameters directly
|
||||
# from the request.
|
||||
|
||||
nonlocal client, redirect_uri, state
|
||||
if client is None and attempt_load_client:
|
||||
# make last-ditch attempt to load the client
|
||||
client_id = best_effort_extract_string("client_id", params)
|
||||
client = await self.provider.get_client(client_id) if client_id else None
|
||||
if redirect_uri is None and client:
|
||||
# make last-ditch effort to load the redirect uri
|
||||
try:
|
||||
if params is not None and "redirect_uri" not in params:
|
||||
raw_redirect_uri = None
|
||||
else:
|
||||
raw_redirect_uri = AnyUrlModel.model_validate(
|
||||
best_effort_extract_string("redirect_uri", params)
|
||||
).root
|
||||
redirect_uri = client.validate_redirect_uri(raw_redirect_uri)
|
||||
except (ValidationError, InvalidRedirectUriError):
|
||||
# if the redirect URI is invalid, ignore it & just return the
|
||||
# initial error
|
||||
pass
|
||||
|
||||
# the error response MUST contain the state specified by the client, if any
|
||||
if state is None: # pragma: no cover
|
||||
# make last-ditch effort to load state
|
||||
state = best_effort_extract_string("state", params)
|
||||
|
||||
error_resp = AuthorizationErrorResponse(
|
||||
error=error,
|
||||
error_description=error_description,
|
||||
state=state,
|
||||
)
|
||||
|
||||
if redirect_uri and client:
|
||||
return RedirectResponse(
|
||||
url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)),
|
||||
status_code=302,
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
else:
|
||||
return PydanticJSONResponse(
|
||||
status_code=400,
|
||||
content=error_resp,
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse request parameters
|
||||
if request.method == "GET":
|
||||
# Convert query_params to dict for pydantic validation
|
||||
params = request.query_params
|
||||
else:
|
||||
# Parse form data for POST requests
|
||||
params = await request.form()
|
||||
|
||||
# Save state if it exists, even before validation
|
||||
state = best_effort_extract_string("state", params)
|
||||
|
||||
try:
|
||||
auth_request = AuthorizationRequest.model_validate(params)
|
||||
state = auth_request.state # Update with validated state
|
||||
except ValidationError as validation_error:
|
||||
error: AuthorizationErrorCode = "invalid_request"
|
||||
for e in validation_error.errors():
|
||||
if e["loc"] == ("response_type",) and e["type"] == "literal_error":
|
||||
error = "unsupported_response_type"
|
||||
break
|
||||
return await error_response(error, stringify_pydantic_error(validation_error))
|
||||
|
||||
# Get client information
|
||||
client = await self.provider.get_client(
|
||||
auth_request.client_id,
|
||||
)
|
||||
if not client:
|
||||
# For client_id validation errors, return direct error (no redirect)
|
||||
return await error_response(
|
||||
error="invalid_request",
|
||||
error_description=f"Client ID '{auth_request.client_id}' not found",
|
||||
attempt_load_client=False,
|
||||
)
|
||||
|
||||
# Validate redirect_uri against client's registered URIs
|
||||
try:
|
||||
redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri)
|
||||
except InvalidRedirectUriError as validation_error:
|
||||
# For redirect_uri validation errors, return direct error (no redirect)
|
||||
return await error_response(
|
||||
error="invalid_request",
|
||||
error_description=validation_error.message,
|
||||
)
|
||||
|
||||
# Validate scope - for scope errors, we can redirect
|
||||
try:
|
||||
scopes = client.validate_scope(auth_request.scope)
|
||||
except InvalidScopeError as validation_error:
|
||||
# For scope errors, redirect with error parameters
|
||||
return await error_response(
|
||||
error="invalid_scope",
|
||||
error_description=validation_error.message,
|
||||
)
|
||||
|
||||
# Setup authorization parameters
|
||||
auth_params = AuthorizationParams(
|
||||
state=state,
|
||||
scopes=scopes,
|
||||
code_challenge=auth_request.code_challenge,
|
||||
redirect_uri=redirect_uri,
|
||||
redirect_uri_provided_explicitly=auth_request.redirect_uri is not None,
|
||||
resource=auth_request.resource, # RFC 8707
|
||||
)
|
||||
|
||||
try:
|
||||
# Let the provider pick the next URI to redirect to
|
||||
return RedirectResponse(
|
||||
url=await self.provider.authorize(
|
||||
client,
|
||||
auth_params,
|
||||
),
|
||||
status_code=302,
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
except AuthorizeError as e:
|
||||
# Handle authorization errors as defined in RFC 6749 Section 4.1.2.1
|
||||
return await error_response(error=e.error, error_description=e.error_description)
|
||||
|
||||
except Exception as validation_error: # pragma: no cover
|
||||
# Catch-all for unexpected errors
|
||||
logger.exception("Unexpected error in authorization_handler", exc_info=validation_error)
|
||||
return await error_response(error="server_error", error_description="An unexpected error occurred")
|
||||
@@ -0,0 +1,29 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataHandler:
|
||||
metadata: OAuthMetadata
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
return PydanticJSONResponse(
|
||||
content=self.metadata,
|
||||
headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProtectedResourceMetadataHandler:
|
||||
metadata: ProtectedResourceMetadata
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
return PydanticJSONResponse(
|
||||
content=self.metadata,
|
||||
headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour
|
||||
)
|
||||
@@ -0,0 +1,136 @@
|
||||
import secrets
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, RootModel, ValidationError
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, RegistrationError, RegistrationErrorCode
|
||||
from mcp.server.auth.settings import ClientRegistrationOptions
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
|
||||
|
||||
|
||||
class RegistrationRequest(RootModel[OAuthClientMetadata]):
|
||||
# this wrapper is a no-op; it's just to separate out the types exposed to the
|
||||
# provider from what we use in the HTTP handler
|
||||
root: OAuthClientMetadata
|
||||
|
||||
|
||||
class RegistrationErrorResponse(BaseModel):
|
||||
error: RegistrationErrorCode
|
||||
error_description: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegistrationHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
options: ClientRegistrationOptions
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
# Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1
|
||||
try:
|
||||
# Parse request body as JSON
|
||||
body = await request.json()
|
||||
client_metadata = OAuthClientMetadata.model_validate(body)
|
||||
|
||||
# Scope validation is handled below
|
||||
except ValidationError as validation_error:
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description=stringify_pydantic_error(validation_error),
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
client_id = str(uuid4())
|
||||
|
||||
# If auth method is None, default to client_secret_post
|
||||
if client_metadata.token_endpoint_auth_method is None:
|
||||
client_metadata.token_endpoint_auth_method = "client_secret_post"
|
||||
|
||||
client_secret = None
|
||||
if client_metadata.token_endpoint_auth_method != "none": # pragma: no branch
|
||||
# cryptographically secure random 32-byte hex string
|
||||
client_secret = secrets.token_hex(32)
|
||||
|
||||
if client_metadata.scope is None and self.options.default_scopes is not None:
|
||||
client_metadata.scope = " ".join(self.options.default_scopes)
|
||||
elif client_metadata.scope is not None and self.options.valid_scopes is not None:
|
||||
requested_scopes = set(client_metadata.scope.split())
|
||||
valid_scopes = set(self.options.valid_scopes)
|
||||
if not requested_scopes.issubset(valid_scopes): # pragma: no branch
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description="Requested scopes are not valid: "
|
||||
f"{', '.join(requested_scopes - valid_scopes)}",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
if not {"authorization_code", "refresh_token"}.issubset(set(client_metadata.grant_types)):
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description="grant_types must be authorization_code and refresh_token",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# The MCP spec requires servers to use the authorization `code` flow
|
||||
# with PKCE
|
||||
if "code" not in client_metadata.response_types:
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description="response_types must include 'code' for authorization_code grant",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
client_id_issued_at = int(time.time())
|
||||
client_secret_expires_at = (
|
||||
client_id_issued_at + self.options.client_secret_expiry_seconds
|
||||
if self.options.client_secret_expiry_seconds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
client_info = OAuthClientInformationFull(
|
||||
client_id=client_id,
|
||||
client_id_issued_at=client_id_issued_at,
|
||||
client_secret=client_secret,
|
||||
client_secret_expires_at=client_secret_expires_at,
|
||||
# passthrough information from the client request
|
||||
redirect_uris=client_metadata.redirect_uris,
|
||||
token_endpoint_auth_method=client_metadata.token_endpoint_auth_method,
|
||||
grant_types=client_metadata.grant_types,
|
||||
response_types=client_metadata.response_types,
|
||||
client_name=client_metadata.client_name,
|
||||
client_uri=client_metadata.client_uri,
|
||||
logo_uri=client_metadata.logo_uri,
|
||||
scope=client_metadata.scope,
|
||||
contacts=client_metadata.contacts,
|
||||
tos_uri=client_metadata.tos_uri,
|
||||
policy_uri=client_metadata.policy_uri,
|
||||
jwks_uri=client_metadata.jwks_uri,
|
||||
jwks=client_metadata.jwks,
|
||||
software_id=client_metadata.software_id,
|
||||
software_version=client_metadata.software_version,
|
||||
)
|
||||
try:
|
||||
# Register client
|
||||
await self.provider.register_client(client_info)
|
||||
|
||||
# Return client information
|
||||
return PydanticJSONResponse(content=client_info, status_code=201)
|
||||
except RegistrationError as e:
|
||||
# Handle registration errors as defined in RFC 7591 Section 3.2.2
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(error=e.error, error_description=e.error_description),
|
||||
status_code=400,
|
||||
)
|
||||
@@ -0,0 +1,91 @@
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from mcp.server.auth.errors import (
|
||||
stringify_pydantic_error,
|
||||
)
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
|
||||
from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, RefreshToken
|
||||
|
||||
|
||||
class RevocationRequest(BaseModel):
|
||||
"""
|
||||
# See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
|
||||
"""
|
||||
|
||||
token: str
|
||||
token_type_hint: Literal["access_token", "refresh_token"] | None = None
|
||||
client_id: str
|
||||
client_secret: str | None
|
||||
|
||||
|
||||
class RevocationErrorResponse(BaseModel):
|
||||
error: Literal["invalid_request", "unauthorized_client"]
|
||||
error_description: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RevocationHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
client_authenticator: ClientAuthenticator
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
"""
|
||||
Handler for the OAuth 2.0 Token Revocation endpoint.
|
||||
"""
|
||||
try:
|
||||
client = await self.client_authenticator.authenticate_request(request)
|
||||
except AuthenticationError as e: # pragma: no cover
|
||||
return PydanticJSONResponse(
|
||||
status_code=401,
|
||||
content=RevocationErrorResponse(
|
||||
error="unauthorized_client",
|
||||
error_description=e.message,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
form_data = await request.form()
|
||||
revocation_request = RevocationRequest.model_validate(dict(form_data))
|
||||
except ValidationError as e:
|
||||
return PydanticJSONResponse(
|
||||
status_code=400,
|
||||
content=RevocationErrorResponse(
|
||||
error="invalid_request",
|
||||
error_description=stringify_pydantic_error(e),
|
||||
),
|
||||
)
|
||||
|
||||
loaders = [
|
||||
self.provider.load_access_token,
|
||||
partial(self.provider.load_refresh_token, client),
|
||||
]
|
||||
if revocation_request.token_type_hint == "refresh_token": # pragma: no cover
|
||||
loaders = reversed(loaders)
|
||||
|
||||
token: None | AccessToken | RefreshToken = None
|
||||
for loader in loaders:
|
||||
token = await loader(revocation_request.token)
|
||||
if token is not None:
|
||||
break
|
||||
|
||||
# if token is not found, just return HTTP 200 per the RFC
|
||||
if token and token.client_id == client.client_id:
|
||||
# Revoke token; provider is not meant to be able to do validation
|
||||
# at this point that would result in an error
|
||||
await self.provider.revoke_token(token)
|
||||
|
||||
# Return successful empty response
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Cache-Control": "no-store",
|
||||
"Pragma": "no-cache",
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,241 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
|
||||
from starlette.requests import Request
|
||||
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode
|
||||
from mcp.shared.auth import OAuthToken
|
||||
|
||||
|
||||
class AuthorizationCodeRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
|
||||
grant_type: Literal["authorization_code"]
|
||||
code: str = Field(..., description="The authorization code")
|
||||
redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize")
|
||||
client_id: str
|
||||
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
|
||||
client_secret: str | None = None
|
||||
# See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
|
||||
code_verifier: str = Field(..., description="PKCE code verifier")
|
||||
# RFC 8707 resource indicator
|
||||
resource: str | None = Field(None, description="Resource indicator for the token")
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-6
|
||||
grant_type: Literal["refresh_token"]
|
||||
refresh_token: str = Field(..., description="The refresh token")
|
||||
scope: str | None = Field(None, description="Optional scope parameter")
|
||||
client_id: str
|
||||
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
|
||||
client_secret: str | None = None
|
||||
# RFC 8707 resource indicator
|
||||
resource: str | None = Field(None, description="Resource indicator for the token")
|
||||
|
||||
|
||||
class TokenRequest(
|
||||
RootModel[
|
||||
Annotated[
|
||||
AuthorizationCodeRequest | RefreshTokenRequest,
|
||||
Field(discriminator="grant_type"),
|
||||
]
|
||||
]
|
||||
):
|
||||
root: Annotated[
|
||||
AuthorizationCodeRequest | RefreshTokenRequest,
|
||||
Field(discriminator="grant_type"),
|
||||
]
|
||||
|
||||
|
||||
class TokenErrorResponse(BaseModel):
|
||||
"""
|
||||
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
||||
"""
|
||||
|
||||
error: TokenErrorCode
|
||||
error_description: str | None = None
|
||||
error_uri: AnyHttpUrl | None = None
|
||||
|
||||
|
||||
class TokenSuccessResponse(RootModel[OAuthToken]):
|
||||
# this is just a wrapper over OAuthToken; the only reason we do this
|
||||
# is to have some separation between the HTTP response type, and the
|
||||
# type returned by the provider
|
||||
root: OAuthToken
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
client_authenticator: ClientAuthenticator
|
||||
|
||||
def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
|
||||
status_code = 200
|
||||
if isinstance(obj, TokenErrorResponse):
|
||||
status_code = 400
|
||||
|
||||
return PydanticJSONResponse(
|
||||
content=obj,
|
||||
status_code=status_code,
|
||||
headers={
|
||||
"Cache-Control": "no-store",
|
||||
"Pragma": "no-cache",
|
||||
},
|
||||
)
|
||||
|
||||
async def handle(self, request: Request):
|
||||
try:
|
||||
client_info = await self.client_authenticator.authenticate_request(request)
|
||||
except AuthenticationError as e:
|
||||
# Authentication failures should return 401
|
||||
return PydanticJSONResponse(
|
||||
content=TokenErrorResponse(
|
||||
error="unauthorized_client",
|
||||
error_description=e.message,
|
||||
),
|
||||
status_code=401,
|
||||
headers={
|
||||
"Cache-Control": "no-store",
|
||||
"Pragma": "no-cache",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
form_data = await request.form()
|
||||
token_request = TokenRequest.model_validate(dict(form_data)).root
|
||||
except ValidationError as validation_error: # pragma: no cover
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_request",
|
||||
error_description=stringify_pydantic_error(validation_error),
|
||||
)
|
||||
)
|
||||
|
||||
if token_request.grant_type not in client_info.grant_types: # pragma: no cover
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="unsupported_grant_type",
|
||||
error_description=(f"Unsupported grant type (supported grant types are {client_info.grant_types})"),
|
||||
)
|
||||
)
|
||||
|
||||
tokens: OAuthToken
|
||||
|
||||
match token_request:
|
||||
case AuthorizationCodeRequest():
|
||||
auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
|
||||
if auth_code is None or auth_code.client_id != token_request.client_id:
|
||||
# if code belongs to different client, pretend it doesn't exist
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="authorization code does not exist",
|
||||
)
|
||||
)
|
||||
|
||||
# make auth codes expire after a deadline
|
||||
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
|
||||
if auth_code.expires_at < time.time():
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="authorization code has expired",
|
||||
)
|
||||
)
|
||||
|
||||
# verify redirect_uri doesn't change between /authorize and /tokens
|
||||
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
|
||||
if auth_code.redirect_uri_provided_explicitly:
|
||||
authorize_request_redirect_uri = auth_code.redirect_uri
|
||||
else: # pragma: no cover
|
||||
authorize_request_redirect_uri = None
|
||||
|
||||
# Convert both sides to strings for comparison to handle AnyUrl vs string issues
|
||||
token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None
|
||||
auth_redirect_str = (
|
||||
str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None
|
||||
)
|
||||
|
||||
if token_redirect_str != auth_redirect_str:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_request",
|
||||
error_description=("redirect_uri did not match the one used when creating auth code"),
|
||||
)
|
||||
)
|
||||
|
||||
# Verify PKCE code verifier
|
||||
sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
|
||||
hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
|
||||
|
||||
if hashed_code_verifier != auth_code.code_challenge:
|
||||
# see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="incorrect code_verifier",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Exchange authorization code for tokens
|
||||
tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
|
||||
except TokenError as e:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error=e.error,
|
||||
error_description=e.error_description,
|
||||
)
|
||||
)
|
||||
|
||||
case RefreshTokenRequest(): # pragma: no cover
|
||||
refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
|
||||
if refresh_token is None or refresh_token.client_id != token_request.client_id:
|
||||
# if token belongs to different client, pretend it doesn't exist
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="refresh token does not exist",
|
||||
)
|
||||
)
|
||||
|
||||
if refresh_token.expires_at and refresh_token.expires_at < time.time():
|
||||
# if the refresh token has expired, pretend it doesn't exist
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="refresh token has expired",
|
||||
)
|
||||
)
|
||||
|
||||
# Parse scopes if provided
|
||||
scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes
|
||||
|
||||
for scope in scopes:
|
||||
if scope not in refresh_token.scopes:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_scope",
|
||||
error_description=(f"cannot request scope `{scope}` not provided by refresh token"),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Exchange refresh token for new tokens
|
||||
tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
|
||||
except TokenError as e:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error=e.error,
|
||||
error_description=e.error_description,
|
||||
)
|
||||
)
|
||||
|
||||
return self.response(TokenSuccessResponse(root=tokens))
|
||||
@@ -0,0 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
class PydanticJSONResponse(JSONResponse):
|
||||
# use pydantic json serialization instead of the stock `json.dumps`,
|
||||
# so that we can handle serializing pydantic models like AnyHttpUrl
|
||||
def render(self, content: Any) -> bytes:
|
||||
return content.model_dump_json(exclude_none=True).encode("utf-8")
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Middleware for MCP authorization.
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,48 @@
|
||||
import contextvars
|
||||
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
|
||||
from mcp.server.auth.provider import AccessToken
|
||||
|
||||
# Create a contextvar to store the authenticated user
|
||||
# The default is None, indicating no authenticated user is present
|
||||
auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None)
|
||||
|
||||
|
||||
def get_access_token() -> AccessToken | None:
|
||||
"""
|
||||
Get the access token from the current context.
|
||||
|
||||
Returns:
|
||||
The access token if an authenticated user is available, None otherwise.
|
||||
"""
|
||||
auth_user = auth_context_var.get()
|
||||
return auth_user.access_token if auth_user else None
|
||||
|
||||
|
||||
class AuthContextMiddleware:
|
||||
"""
|
||||
Middleware that extracts the authenticated user from the request
|
||||
and sets it in a contextvar for easy access throughout the request lifecycle.
|
||||
|
||||
This middleware should be added after the AuthenticationMiddleware in the
|
||||
middleware stack to ensure that the user is properly authenticated before
|
||||
being stored in the context.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
user = scope.get("user")
|
||||
if isinstance(user, AuthenticatedUser):
|
||||
# Set the authenticated user in the contextvar
|
||||
token = auth_context_var.set(user)
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
finally:
|
||||
auth_context_var.reset(token)
|
||||
else:
|
||||
# No authenticated user, just process the request
|
||||
await self.app(scope, receive, send)
|
||||
@@ -0,0 +1,128 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from pydantic import AnyHttpUrl
|
||||
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from mcp.server.auth.provider import AccessToken, TokenVerifier
|
||||
|
||||
|
||||
class AuthenticatedUser(SimpleUser):
|
||||
"""User with authentication info."""
|
||||
|
||||
def __init__(self, auth_info: AccessToken):
|
||||
super().__init__(auth_info.client_id)
|
||||
self.access_token = auth_info
|
||||
self.scopes = auth_info.scopes
|
||||
|
||||
|
||||
class BearerAuthBackend(AuthenticationBackend):
|
||||
"""
|
||||
Authentication backend that validates Bearer tokens using a TokenVerifier.
|
||||
"""
|
||||
|
||||
def __init__(self, token_verifier: TokenVerifier):
|
||||
self.token_verifier = token_verifier
|
||||
|
||||
async def authenticate(self, conn: HTTPConnection):
|
||||
auth_header = next(
|
||||
(conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"),
|
||||
None,
|
||||
)
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
return None
|
||||
|
||||
token = auth_header[7:] # Remove "Bearer " prefix
|
||||
|
||||
# Validate the token with the verifier
|
||||
auth_info = await self.token_verifier.verify_token(token)
|
||||
|
||||
if not auth_info:
|
||||
return None
|
||||
|
||||
if auth_info.expires_at and auth_info.expires_at < int(time.time()):
|
||||
return None
|
||||
|
||||
return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info)
|
||||
|
||||
|
||||
class RequireAuthMiddleware:
|
||||
"""
|
||||
Middleware that requires a valid Bearer token in the Authorization header.
|
||||
|
||||
This will validate the token with the auth provider and store the resulting
|
||||
auth info in the request state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: Any,
|
||||
required_scopes: list[str],
|
||||
resource_metadata_url: AnyHttpUrl | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the middleware.
|
||||
|
||||
Args:
|
||||
app: ASGI application
|
||||
required_scopes: List of scopes that the token must have
|
||||
resource_metadata_url: Optional protected resource metadata URL for WWW-Authenticate header
|
||||
"""
|
||||
self.app = app
|
||||
self.required_scopes = required_scopes
|
||||
self.resource_metadata_url = resource_metadata_url
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
auth_user = scope.get("user")
|
||||
if not isinstance(auth_user, AuthenticatedUser):
|
||||
await self._send_auth_error(
|
||||
send, status_code=401, error="invalid_token", description="Authentication required"
|
||||
)
|
||||
return
|
||||
|
||||
auth_credentials = scope.get("auth")
|
||||
|
||||
for required_scope in self.required_scopes:
|
||||
# auth_credentials should always be provided; this is just paranoia
|
||||
if auth_credentials is None or required_scope not in auth_credentials.scopes:
|
||||
await self._send_auth_error(
|
||||
send, status_code=403, error="insufficient_scope", description=f"Required scope: {required_scope}"
|
||||
)
|
||||
return
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def _send_auth_error(self, send: Send, status_code: int, error: str, description: str) -> None:
|
||||
"""Send an authentication error response with WWW-Authenticate header."""
|
||||
# Build WWW-Authenticate header value
|
||||
www_auth_parts = [f'error="{error}"', f'error_description="{description}"']
|
||||
if self.resource_metadata_url: # pragma: no cover
|
||||
www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"')
|
||||
|
||||
www_authenticate = f"Bearer {', '.join(www_auth_parts)}"
|
||||
|
||||
# Send response
|
||||
body = {"error": error, "error_description": description}
|
||||
body_bytes = json.dumps(body).encode()
|
||||
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": status_code,
|
||||
"headers": [
|
||||
(b"content-type", b"application/json"),
|
||||
(b"content-length", str(len(body_bytes)).encode()),
|
||||
(b"www-authenticate", www_authenticate.encode()),
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": body_bytes,
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,115 @@
|
||||
import base64
|
||||
import binascii
|
||||
import hmac
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
|
||||
from starlette.requests import Request
|
||||
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||
from mcp.shared.auth import OAuthClientInformationFull
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message # pragma: no cover
|
||||
|
||||
|
||||
class ClientAuthenticator:
|
||||
"""
|
||||
ClientAuthenticator is a callable which validates requests from a client
|
||||
application, used to verify /token calls.
|
||||
If, during registration, the client requested to be issued a secret, the
|
||||
authenticator asserts that /token calls must be authenticated with
|
||||
that same token.
|
||||
NOTE: clients can opt for no authentication during registration, in which case this
|
||||
logic is skipped.
|
||||
"""
|
||||
|
||||
def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
|
||||
"""
|
||||
Initialize the dependency.
|
||||
|
||||
Args:
|
||||
provider: Provider to look up client information
|
||||
"""
|
||||
self.provider = provider
|
||||
|
||||
async def authenticate_request(self, request: Request) -> OAuthClientInformationFull:
|
||||
"""
|
||||
Authenticate a client from an HTTP request.
|
||||
|
||||
Extracts client credentials from the appropriate location based on the
|
||||
client's registered authentication method and validates them.
|
||||
|
||||
Args:
|
||||
request: The HTTP request containing client credentials
|
||||
|
||||
Returns:
|
||||
The authenticated client information
|
||||
|
||||
Raises:
|
||||
AuthenticationError: If authentication fails
|
||||
"""
|
||||
form_data = await request.form()
|
||||
client_id = form_data.get("client_id")
|
||||
if not client_id:
|
||||
raise AuthenticationError("Missing client_id")
|
||||
|
||||
client = await self.provider.get_client(str(client_id))
|
||||
if not client:
|
||||
raise AuthenticationError("Invalid client_id") # pragma: no cover
|
||||
|
||||
request_client_secret: str | None = None
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
|
||||
if client.token_endpoint_auth_method == "client_secret_basic":
|
||||
if not auth_header.startswith("Basic "):
|
||||
raise AuthenticationError("Missing or invalid Basic authentication in Authorization header")
|
||||
|
||||
try:
|
||||
encoded_credentials = auth_header[6:] # Remove "Basic " prefix
|
||||
decoded = base64.b64decode(encoded_credentials).decode("utf-8")
|
||||
if ":" not in decoded:
|
||||
raise ValueError("Invalid Basic auth format")
|
||||
basic_client_id, request_client_secret = decoded.split(":", 1)
|
||||
|
||||
# URL-decode both parts per RFC 6749 Section 2.3.1
|
||||
basic_client_id = unquote(basic_client_id)
|
||||
request_client_secret = unquote(request_client_secret)
|
||||
|
||||
if basic_client_id != client_id:
|
||||
raise AuthenticationError("Client ID mismatch in Basic auth")
|
||||
except (ValueError, UnicodeDecodeError, binascii.Error):
|
||||
raise AuthenticationError("Invalid Basic authentication header")
|
||||
|
||||
elif client.token_endpoint_auth_method == "client_secret_post":
|
||||
raw_form_data = form_data.get("client_secret")
|
||||
# form_data.get() can return a UploadFile or None, so we need to check if it's a string
|
||||
if isinstance(raw_form_data, str):
|
||||
request_client_secret = str(raw_form_data)
|
||||
|
||||
elif client.token_endpoint_auth_method == "none":
|
||||
request_client_secret = None
|
||||
else:
|
||||
raise AuthenticationError( # pragma: no cover
|
||||
f"Unsupported auth method: {client.token_endpoint_auth_method}"
|
||||
)
|
||||
|
||||
# If client from the store expects a secret, validate that the request provides
|
||||
# that secret
|
||||
if client.client_secret: # pragma: no branch
|
||||
if not request_client_secret:
|
||||
raise AuthenticationError("Client secret is required") # pragma: no cover
|
||||
|
||||
# hmac.compare_digest requires that both arguments are either bytes or a `str` containing
|
||||
# only ASCII characters. Since we do not control `request_client_secret`, we encode both
|
||||
# arguments to bytes.
|
||||
if not hmac.compare_digest(client.client_secret.encode(), request_client_secret.encode()):
|
||||
raise AuthenticationError("Invalid client_secret") # pragma: no cover
|
||||
|
||||
if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):
|
||||
raise AuthenticationError("Client secret has expired") # pragma: no cover
|
||||
|
||||
return client
|
||||
301
.venv/lib/python3.11/site-packages/mcp/server/auth/provider.py
Normal file
301
.venv/lib/python3.11/site-packages/mcp/server/auth/provider.py
Normal file
@@ -0,0 +1,301 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, Literal, Protocol, TypeVar
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
from pydantic import AnyUrl, BaseModel
|
||||
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||
|
||||
|
||||
class AuthorizationParams(BaseModel):
|
||||
state: str | None
|
||||
scopes: list[str] | None
|
||||
code_challenge: str
|
||||
redirect_uri: AnyUrl
|
||||
redirect_uri_provided_explicitly: bool
|
||||
resource: str | None = None # RFC 8707 resource indicator
|
||||
|
||||
|
||||
class AuthorizationCode(BaseModel):
|
||||
code: str
|
||||
scopes: list[str]
|
||||
expires_at: float
|
||||
client_id: str
|
||||
code_challenge: str
|
||||
redirect_uri: AnyUrl
|
||||
redirect_uri_provided_explicitly: bool
|
||||
resource: str | None = None # RFC 8707 resource indicator
|
||||
|
||||
|
||||
class RefreshToken(BaseModel):
|
||||
token: str
|
||||
client_id: str
|
||||
scopes: list[str]
|
||||
expires_at: int | None = None
|
||||
|
||||
|
||||
class AccessToken(BaseModel):
|
||||
token: str
|
||||
client_id: str
|
||||
scopes: list[str]
|
||||
expires_at: int | None = None
|
||||
resource: str | None = None # RFC 8707 resource indicator
|
||||
|
||||
|
||||
RegistrationErrorCode = Literal[
|
||||
"invalid_redirect_uri",
|
||||
"invalid_client_metadata",
|
||||
"invalid_software_statement",
|
||||
"unapproved_software_statement",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistrationError(Exception):
|
||||
error: RegistrationErrorCode
|
||||
error_description: str | None = None
|
||||
|
||||
|
||||
AuthorizationErrorCode = Literal[
|
||||
"invalid_request",
|
||||
"unauthorized_client",
|
||||
"access_denied",
|
||||
"unsupported_response_type",
|
||||
"invalid_scope",
|
||||
"server_error",
|
||||
"temporarily_unavailable",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuthorizeError(Exception):
|
||||
error: AuthorizationErrorCode
|
||||
error_description: str | None = None
|
||||
|
||||
|
||||
TokenErrorCode = Literal[
|
||||
"invalid_request",
|
||||
"invalid_client",
|
||||
"invalid_grant",
|
||||
"unauthorized_client",
|
||||
"unsupported_grant_type",
|
||||
"invalid_scope",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenError(Exception):
|
||||
error: TokenErrorCode
|
||||
error_description: str | None = None
|
||||
|
||||
|
||||
class TokenVerifier(Protocol):
|
||||
"""Protocol for verifying bearer tokens."""
|
||||
|
||||
async def verify_token(self, token: str) -> AccessToken | None:
|
||||
"""Verify a bearer token and return access info if valid."""
|
||||
|
||||
|
||||
# NOTE: FastMCP doesn't render any of these types in the user response, so it's
|
||||
# OK to add fields to subclasses which should not be exposed externally.
|
||||
AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode)
|
||||
RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken)
|
||||
AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken)
|
||||
|
||||
|
||||
class OAuthAuthorizationServerProvider(Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]):
|
||||
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
||||
"""
|
||||
Retrieves client information by client ID.
|
||||
|
||||
Implementors MAY raise NotImplementedError if dynamic client registration is
|
||||
disabled in ClientRegistrationOptions.
|
||||
|
||||
Args:
|
||||
client_id: The ID of the client to retrieve.
|
||||
|
||||
Returns:
|
||||
The client information, or None if the client does not exist.
|
||||
"""
|
||||
|
||||
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
|
||||
"""
|
||||
Saves client information as part of registering it.
|
||||
|
||||
Implementors MAY raise NotImplementedError if dynamic client registration is
|
||||
disabled in ClientRegistrationOptions.
|
||||
|
||||
Args:
|
||||
client_info: The client metadata to register.
|
||||
|
||||
Raises:
|
||||
RegistrationError: If the client metadata is invalid.
|
||||
"""
|
||||
|
||||
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
|
||||
"""
|
||||
Called as part of the /authorize endpoint, and returns a URL that the client
|
||||
will be redirected to.
|
||||
Many MCP implementations will redirect to a third-party provider to perform
|
||||
a second OAuth exchange with that provider. In this sort of setup, the client
|
||||
has an OAuth connection with the MCP server, and the MCP server has an OAuth
|
||||
connection with the 3rd-party provider. At the end of this flow, the client
|
||||
should be redirected to the redirect_uri from params.redirect_uri.
|
||||
|
||||
+--------+ +------------+ +-------------------+
|
||||
| | | | | |
|
||||
| Client | --> | MCP Server | --> | 3rd Party OAuth |
|
||||
| | | | | Server |
|
||||
+--------+ +------------+ +-------------------+
|
||||
| ^ |
|
||||
+------------+ | | |
|
||||
| | | | Redirect |
|
||||
|redirect_uri|<-----+ +------------------+
|
||||
| |
|
||||
+------------+
|
||||
|
||||
Implementations will need to define another handler on the MCP server return
|
||||
flow to perform the second redirect, and generate and store an authorization
|
||||
code as part of completing the OAuth authorization step.
|
||||
|
||||
Implementations SHOULD generate an authorization code with at least 160 bits of
|
||||
entropy,
|
||||
and MUST generate an authorization code with at least 128 bits of entropy.
|
||||
See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10.
|
||||
|
||||
Args:
|
||||
client: The client requesting authorization.
|
||||
params: The parameters of the authorization request.
|
||||
|
||||
Returns:
|
||||
A URL to redirect the client to for authorization.
|
||||
|
||||
Raises:
|
||||
AuthorizeError: If the authorization request is invalid.
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_authorization_code(
|
||||
self, client: OAuthClientInformationFull, authorization_code: str
|
||||
) -> AuthorizationCodeT | None:
|
||||
"""
|
||||
Loads an AuthorizationCode by its code.
|
||||
|
||||
Args:
|
||||
client: The client that requested the authorization code.
|
||||
authorization_code: The authorization code to get the challenge for.
|
||||
|
||||
Returns:
|
||||
The AuthorizationCode, or None if not found
|
||||
"""
|
||||
...
|
||||
|
||||
async def exchange_authorization_code(
|
||||
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCodeT
|
||||
) -> OAuthToken:
|
||||
"""
|
||||
Exchanges an authorization code for an access token and refresh token.
|
||||
|
||||
Args:
|
||||
client: The client exchanging the authorization code.
|
||||
authorization_code: The authorization code to exchange.
|
||||
|
||||
Returns:
|
||||
The OAuth token, containing access and refresh tokens.
|
||||
|
||||
Raises:
|
||||
TokenError: If the request is invalid
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshTokenT | None:
|
||||
"""
|
||||
Loads a RefreshToken by its token string.
|
||||
|
||||
Args:
|
||||
client: The client that is requesting to load the refresh token.
|
||||
refresh_token: The refresh token string to load.
|
||||
|
||||
Returns:
|
||||
The RefreshToken object if found, or None if not found.
|
||||
"""
|
||||
...
|
||||
|
||||
async def exchange_refresh_token(
|
||||
self,
|
||||
client: OAuthClientInformationFull,
|
||||
refresh_token: RefreshTokenT,
|
||||
scopes: list[str],
|
||||
) -> OAuthToken:
|
||||
"""
|
||||
Exchanges a refresh token for an access token and refresh token.
|
||||
|
||||
Implementations SHOULD rotate both the access token and refresh token.
|
||||
|
||||
Args:
|
||||
client: The client exchanging the refresh token.
|
||||
refresh_token: The refresh token to exchange.
|
||||
scopes: Optional scopes to request with the new access token.
|
||||
|
||||
Returns:
|
||||
The OAuth token, containing access and refresh tokens.
|
||||
|
||||
Raises:
|
||||
TokenError: If the request is invalid
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_access_token(self, token: str) -> AccessTokenT | None:
|
||||
"""
|
||||
Loads an access token by its token.
|
||||
|
||||
Args:
|
||||
token: The access token to verify.
|
||||
|
||||
Returns:
|
||||
The AuthInfo, or None if the token is invalid.
|
||||
"""
|
||||
|
||||
async def revoke_token(
|
||||
self,
|
||||
token: AccessTokenT | RefreshTokenT,
|
||||
) -> None:
|
||||
"""
|
||||
Revokes an access or refresh token.
|
||||
|
||||
If the given token is invalid or already revoked, this method should do nothing.
|
||||
|
||||
Implementations SHOULD revoke both the access token and its corresponding
|
||||
refresh token, regardless of which of the access token or refresh token is
|
||||
provided.
|
||||
|
||||
Args:
|
||||
token: the token to revoke
|
||||
"""
|
||||
|
||||
|
||||
def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str:
|
||||
parsed_uri = urlparse(redirect_uri_base)
|
||||
query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query).items() for v in vs]
|
||||
for k, v in params.items():
|
||||
if v is not None:
|
||||
query_params.append((k, v))
|
||||
|
||||
redirect_uri = urlunparse(parsed_uri._replace(query=urlencode(query_params)))
|
||||
return redirect_uri
|
||||
|
||||
|
||||
class ProviderTokenVerifier(TokenVerifier):
|
||||
"""Token verifier that uses an OAuthAuthorizationServerProvider.
|
||||
|
||||
This is provided for backwards compatibility with existing auth_server_provider
|
||||
configurations. For new implementations using AS/RS separation, consider using
|
||||
the TokenVerifier protocol with a dedicated implementation like IntrospectionTokenVerifier.
|
||||
"""
|
||||
|
||||
def __init__(self, provider: "OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]"):
|
||||
self.provider = provider
|
||||
|
||||
async def verify_token(self, token: str) -> AccessToken | None:
|
||||
"""Verify token using the provider's load_access_token method."""
|
||||
return await self.provider.load_access_token(token)
|
||||
253
.venv/lib/python3.11/site-packages/mcp/server/auth/routes.py
Normal file
253
.venv/lib/python3.11/site-packages/mcp/server/auth/routes.py
Normal file
@@ -0,0 +1,253 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import AnyHttpUrl
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Route, request_response # type: ignore
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from mcp.server.auth.handlers.authorize import AuthorizationHandler
|
||||
from mcp.server.auth.handlers.metadata import MetadataHandler
|
||||
from mcp.server.auth.handlers.register import RegistrationHandler
|
||||
from mcp.server.auth.handlers.revoke import RevocationHandler
|
||||
from mcp.server.auth.handlers.token import TokenHandler
|
||||
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
|
||||
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
|
||||
from mcp.shared.auth import OAuthMetadata
|
||||
|
||||
|
||||
def validate_issuer_url(url: AnyHttpUrl):
|
||||
"""
|
||||
Validate that the issuer URL meets OAuth 2.0 requirements.
|
||||
|
||||
Args:
|
||||
url: The issuer URL to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If the issuer URL is invalid
|
||||
"""
|
||||
|
||||
# RFC 8414 requires HTTPS, but we allow localhost HTTP for testing
|
||||
if (
|
||||
url.scheme != "https"
|
||||
and url.host != "localhost"
|
||||
and (url.host is not None and not url.host.startswith("127.0.0.1"))
|
||||
):
|
||||
raise ValueError("Issuer URL must be HTTPS") # pragma: no cover
|
||||
|
||||
# No fragments or query parameters allowed
|
||||
if url.fragment:
|
||||
raise ValueError("Issuer URL must not have a fragment") # pragma: no cover
|
||||
if url.query:
|
||||
raise ValueError("Issuer URL must not have a query string") # pragma: no cover
|
||||
|
||||
|
||||
AUTHORIZATION_PATH = "/authorize"
|
||||
TOKEN_PATH = "/token"
|
||||
REGISTRATION_PATH = "/register"
|
||||
REVOCATION_PATH = "/revoke"
|
||||
|
||||
|
||||
def cors_middleware(
|
||||
handler: Callable[[Request], Response | Awaitable[Response]],
|
||||
allow_methods: list[str],
|
||||
) -> ASGIApp:
|
||||
cors_app = CORSMiddleware(
|
||||
app=request_response(handler),
|
||||
allow_origins="*",
|
||||
allow_methods=allow_methods,
|
||||
allow_headers=[MCP_PROTOCOL_VERSION_HEADER],
|
||||
)
|
||||
return cors_app
|
||||
|
||||
|
||||
def create_auth_routes(
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any],
|
||||
issuer_url: AnyHttpUrl,
|
||||
service_documentation_url: AnyHttpUrl | None = None,
|
||||
client_registration_options: ClientRegistrationOptions | None = None,
|
||||
revocation_options: RevocationOptions | None = None,
|
||||
) -> list[Route]:
|
||||
validate_issuer_url(issuer_url)
|
||||
|
||||
client_registration_options = client_registration_options or ClientRegistrationOptions()
|
||||
revocation_options = revocation_options or RevocationOptions()
|
||||
metadata = build_metadata(
|
||||
issuer_url,
|
||||
service_documentation_url,
|
||||
client_registration_options,
|
||||
revocation_options,
|
||||
)
|
||||
client_authenticator = ClientAuthenticator(provider)
|
||||
|
||||
# Create routes
|
||||
# Allow CORS requests for endpoints meant to be hit by the OAuth client
|
||||
# (with the client secret). This is intended to support things like MCP Inspector,
|
||||
# where the client runs in a web browser.
|
||||
routes = [
|
||||
Route(
|
||||
"/.well-known/oauth-authorization-server",
|
||||
endpoint=cors_middleware(
|
||||
MetadataHandler(metadata).handle,
|
||||
["GET", "OPTIONS"],
|
||||
),
|
||||
methods=["GET", "OPTIONS"],
|
||||
),
|
||||
Route(
|
||||
AUTHORIZATION_PATH,
|
||||
# do not allow CORS for authorization endpoint;
|
||||
# clients should just redirect to this
|
||||
endpoint=AuthorizationHandler(provider).handle,
|
||||
methods=["GET", "POST"],
|
||||
),
|
||||
Route(
|
||||
TOKEN_PATH,
|
||||
endpoint=cors_middleware(
|
||||
TokenHandler(provider, client_authenticator).handle,
|
||||
["POST", "OPTIONS"],
|
||||
),
|
||||
methods=["POST", "OPTIONS"],
|
||||
),
|
||||
]
|
||||
|
||||
if client_registration_options.enabled: # pragma: no branch
|
||||
registration_handler = RegistrationHandler(
|
||||
provider,
|
||||
options=client_registration_options,
|
||||
)
|
||||
routes.append(
|
||||
Route(
|
||||
REGISTRATION_PATH,
|
||||
endpoint=cors_middleware(
|
||||
registration_handler.handle,
|
||||
["POST", "OPTIONS"],
|
||||
),
|
||||
methods=["POST", "OPTIONS"],
|
||||
)
|
||||
)
|
||||
|
||||
if revocation_options.enabled: # pragma: no branch
|
||||
revocation_handler = RevocationHandler(provider, client_authenticator)
|
||||
routes.append(
|
||||
Route(
|
||||
REVOCATION_PATH,
|
||||
endpoint=cors_middleware(
|
||||
revocation_handler.handle,
|
||||
["POST", "OPTIONS"],
|
||||
),
|
||||
methods=["POST", "OPTIONS"],
|
||||
)
|
||||
)
|
||||
|
||||
return routes
|
||||
|
||||
|
||||
def build_metadata(
|
||||
issuer_url: AnyHttpUrl,
|
||||
service_documentation_url: AnyHttpUrl | None,
|
||||
client_registration_options: ClientRegistrationOptions,
|
||||
revocation_options: RevocationOptions,
|
||||
) -> OAuthMetadata:
|
||||
authorization_url = AnyHttpUrl(str(issuer_url).rstrip("/") + AUTHORIZATION_PATH)
|
||||
token_url = AnyHttpUrl(str(issuer_url).rstrip("/") + TOKEN_PATH)
|
||||
|
||||
# Create metadata
|
||||
metadata = OAuthMetadata(
|
||||
issuer=issuer_url,
|
||||
authorization_endpoint=authorization_url,
|
||||
token_endpoint=token_url,
|
||||
scopes_supported=client_registration_options.valid_scopes,
|
||||
response_types_supported=["code"],
|
||||
response_modes_supported=None,
|
||||
grant_types_supported=["authorization_code", "refresh_token"],
|
||||
token_endpoint_auth_methods_supported=["client_secret_post", "client_secret_basic"],
|
||||
token_endpoint_auth_signing_alg_values_supported=None,
|
||||
service_documentation=service_documentation_url,
|
||||
ui_locales_supported=None,
|
||||
op_policy_uri=None,
|
||||
op_tos_uri=None,
|
||||
introspection_endpoint=None,
|
||||
code_challenge_methods_supported=["S256"],
|
||||
)
|
||||
|
||||
# Add registration endpoint if supported
|
||||
if client_registration_options.enabled: # pragma: no branch
|
||||
metadata.registration_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REGISTRATION_PATH)
|
||||
|
||||
# Add revocation endpoint if supported
|
||||
if revocation_options.enabled: # pragma: no branch
|
||||
metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH)
|
||||
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post", "client_secret_basic"]
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def build_resource_metadata_url(resource_server_url: AnyHttpUrl) -> AnyHttpUrl:
|
||||
"""
|
||||
Build RFC 9728 compliant protected resource metadata URL.
|
||||
|
||||
Inserts /.well-known/oauth-protected-resource between host and resource path
|
||||
as specified in RFC 9728 §3.1.
|
||||
|
||||
Args:
|
||||
resource_server_url: The resource server URL (e.g., https://example.com/mcp)
|
||||
|
||||
Returns:
|
||||
The metadata URL (e.g., https://example.com/.well-known/oauth-protected-resource/mcp)
|
||||
"""
|
||||
parsed = urlparse(str(resource_server_url))
|
||||
# Handle trailing slash: if path is just "/", treat as empty
|
||||
resource_path = parsed.path if parsed.path != "/" else ""
|
||||
return AnyHttpUrl(f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-protected-resource{resource_path}")
|
||||
|
||||
|
||||
def create_protected_resource_routes(
|
||||
resource_url: AnyHttpUrl,
|
||||
authorization_servers: list[AnyHttpUrl],
|
||||
scopes_supported: list[str] | None = None,
|
||||
resource_name: str | None = None,
|
||||
resource_documentation: AnyHttpUrl | None = None,
|
||||
) -> list[Route]:
|
||||
"""
|
||||
Create routes for OAuth 2.0 Protected Resource Metadata (RFC 9728).
|
||||
|
||||
Args:
|
||||
resource_url: The URL of this resource server
|
||||
authorization_servers: List of authorization servers that can issue tokens
|
||||
scopes_supported: Optional list of scopes supported by this resource
|
||||
|
||||
Returns:
|
||||
List of Starlette routes for protected resource metadata
|
||||
"""
|
||||
from mcp.server.auth.handlers.metadata import ProtectedResourceMetadataHandler
|
||||
from mcp.shared.auth import ProtectedResourceMetadata
|
||||
|
||||
metadata = ProtectedResourceMetadata(
|
||||
resource=resource_url,
|
||||
authorization_servers=authorization_servers,
|
||||
scopes_supported=scopes_supported,
|
||||
resource_name=resource_name,
|
||||
resource_documentation=resource_documentation,
|
||||
# bearer_methods_supported defaults to ["header"] in the model
|
||||
)
|
||||
|
||||
handler = ProtectedResourceMetadataHandler(metadata)
|
||||
|
||||
# RFC 9728 §3.1: Register route at /.well-known/oauth-protected-resource + resource path
|
||||
metadata_url = build_resource_metadata_url(resource_url)
|
||||
# Extract just the path part for route registration
|
||||
parsed = urlparse(str(metadata_url))
|
||||
well_known_path = parsed.path
|
||||
|
||||
return [
|
||||
Route(
|
||||
well_known_path,
|
||||
endpoint=cors_middleware(handler.handle, ["GET", "OPTIONS"]),
|
||||
methods=["GET", "OPTIONS"],
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,30 @@
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
|
||||
|
||||
class ClientRegistrationOptions(BaseModel):
|
||||
enabled: bool = False
|
||||
client_secret_expiry_seconds: int | None = None
|
||||
valid_scopes: list[str] | None = None
|
||||
default_scopes: list[str] | None = None
|
||||
|
||||
|
||||
class RevocationOptions(BaseModel):
|
||||
enabled: bool = False
|
||||
|
||||
|
||||
class AuthSettings(BaseModel):
|
||||
issuer_url: AnyHttpUrl = Field(
|
||||
...,
|
||||
description="OAuth authorization server URL that issues tokens for this resource server.",
|
||||
)
|
||||
service_documentation_url: AnyHttpUrl | None = None
|
||||
client_registration_options: ClientRegistrationOptions | None = None
|
||||
revocation_options: RevocationOptions | None = None
|
||||
required_scopes: list[str] | None = None
|
||||
|
||||
# Resource Server settings (when operating as RS only)
|
||||
resource_server_url: AnyHttpUrl | None = Field(
|
||||
...,
|
||||
description="The URL of the MCP server to be used as the resource identifier "
|
||||
"and base route to look up OAuth Protected Resource Metadata.",
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user