Fix project isolation: Make loadChatHistory respect active project sessions
- Modified loadChatHistory() to check for active project before fetching all sessions - When active project exists, use project.sessions instead of fetching from API - Added detailed console logging to debug session filtering - This prevents ALL sessions from appearing in every project's sidebar Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Request handlers for MCP authorization endpoints.
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,224 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError
|
||||
from starlette.datastructures import FormData, QueryParams
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import RedirectResponse, Response
|
||||
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.provider import (
|
||||
AuthorizationErrorCode,
|
||||
AuthorizationParams,
|
||||
AuthorizeError,
|
||||
OAuthAuthorizationServerProvider,
|
||||
construct_redirect_uri,
|
||||
)
|
||||
from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthorizationRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
|
||||
client_id: str = Field(..., description="The client ID")
|
||||
redirect_uri: AnyUrl | None = Field(None, description="URL to redirect to after authorization")
|
||||
|
||||
# see OAuthClientMetadata; we only support `code`
|
||||
response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow")
|
||||
code_challenge: str = Field(..., description="PKCE code challenge")
|
||||
code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256")
|
||||
state: str | None = Field(None, description="Optional state parameter")
|
||||
scope: str | None = Field(
|
||||
None,
|
||||
description="Optional scope; if specified, should be a space-separated list of scope strings",
|
||||
)
|
||||
resource: str | None = Field(
|
||||
None,
|
||||
description="RFC 8707 resource indicator - the MCP server this token will be used with",
|
||||
)
|
||||
|
||||
|
||||
class AuthorizationErrorResponse(BaseModel):
|
||||
error: AuthorizationErrorCode
|
||||
error_description: str | None
|
||||
error_uri: AnyUrl | None = None
|
||||
# must be set if provided in the request
|
||||
state: str | None = None
|
||||
|
||||
|
||||
def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None:
|
||||
if params is None: # pragma: no cover
|
||||
return None
|
||||
value = params.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class AnyUrlModel(RootModel[AnyUrl]):
|
||||
root: AnyUrl
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthorizationHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
# implements authorization requests for grant_type=code;
|
||||
# see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
|
||||
|
||||
state = None
|
||||
redirect_uri = None
|
||||
client = None
|
||||
params = None
|
||||
|
||||
async def error_response(
|
||||
error: AuthorizationErrorCode,
|
||||
error_description: str | None,
|
||||
attempt_load_client: bool = True,
|
||||
):
|
||||
# Error responses take two different formats:
|
||||
# 1. The request has a valid client ID & redirect_uri: we issue a redirect
|
||||
# back to the redirect_uri with the error response fields as query
|
||||
# parameters. This allows the client to be notified of the error.
|
||||
# 2. Otherwise, we return an error response directly to the end user;
|
||||
# we choose to do so in JSON, but this is left undefined in the
|
||||
# specification.
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1
|
||||
#
|
||||
# This logic is a bit awkward to handle, because the error might be thrown
|
||||
# very early in request validation, before we've done the usual Pydantic
|
||||
# validation, loaded the client, etc. To handle this, error_response()
|
||||
# contains fallback logic which attempts to load the parameters directly
|
||||
# from the request.
|
||||
|
||||
nonlocal client, redirect_uri, state
|
||||
if client is None and attempt_load_client:
|
||||
# make last-ditch attempt to load the client
|
||||
client_id = best_effort_extract_string("client_id", params)
|
||||
client = await self.provider.get_client(client_id) if client_id else None
|
||||
if redirect_uri is None and client:
|
||||
# make last-ditch effort to load the redirect uri
|
||||
try:
|
||||
if params is not None and "redirect_uri" not in params:
|
||||
raw_redirect_uri = None
|
||||
else:
|
||||
raw_redirect_uri = AnyUrlModel.model_validate(
|
||||
best_effort_extract_string("redirect_uri", params)
|
||||
).root
|
||||
redirect_uri = client.validate_redirect_uri(raw_redirect_uri)
|
||||
except (ValidationError, InvalidRedirectUriError):
|
||||
# if the redirect URI is invalid, ignore it & just return the
|
||||
# initial error
|
||||
pass
|
||||
|
||||
# the error response MUST contain the state specified by the client, if any
|
||||
if state is None: # pragma: no cover
|
||||
# make last-ditch effort to load state
|
||||
state = best_effort_extract_string("state", params)
|
||||
|
||||
error_resp = AuthorizationErrorResponse(
|
||||
error=error,
|
||||
error_description=error_description,
|
||||
state=state,
|
||||
)
|
||||
|
||||
if redirect_uri and client:
|
||||
return RedirectResponse(
|
||||
url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)),
|
||||
status_code=302,
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
else:
|
||||
return PydanticJSONResponse(
|
||||
status_code=400,
|
||||
content=error_resp,
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse request parameters
|
||||
if request.method == "GET":
|
||||
# Convert query_params to dict for pydantic validation
|
||||
params = request.query_params
|
||||
else:
|
||||
# Parse form data for POST requests
|
||||
params = await request.form()
|
||||
|
||||
# Save state if it exists, even before validation
|
||||
state = best_effort_extract_string("state", params)
|
||||
|
||||
try:
|
||||
auth_request = AuthorizationRequest.model_validate(params)
|
||||
state = auth_request.state # Update with validated state
|
||||
except ValidationError as validation_error:
|
||||
error: AuthorizationErrorCode = "invalid_request"
|
||||
for e in validation_error.errors():
|
||||
if e["loc"] == ("response_type",) and e["type"] == "literal_error":
|
||||
error = "unsupported_response_type"
|
||||
break
|
||||
return await error_response(error, stringify_pydantic_error(validation_error))
|
||||
|
||||
# Get client information
|
||||
client = await self.provider.get_client(
|
||||
auth_request.client_id,
|
||||
)
|
||||
if not client:
|
||||
# For client_id validation errors, return direct error (no redirect)
|
||||
return await error_response(
|
||||
error="invalid_request",
|
||||
error_description=f"Client ID '{auth_request.client_id}' not found",
|
||||
attempt_load_client=False,
|
||||
)
|
||||
|
||||
# Validate redirect_uri against client's registered URIs
|
||||
try:
|
||||
redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri)
|
||||
except InvalidRedirectUriError as validation_error:
|
||||
# For redirect_uri validation errors, return direct error (no redirect)
|
||||
return await error_response(
|
||||
error="invalid_request",
|
||||
error_description=validation_error.message,
|
||||
)
|
||||
|
||||
# Validate scope - for scope errors, we can redirect
|
||||
try:
|
||||
scopes = client.validate_scope(auth_request.scope)
|
||||
except InvalidScopeError as validation_error:
|
||||
# For scope errors, redirect with error parameters
|
||||
return await error_response(
|
||||
error="invalid_scope",
|
||||
error_description=validation_error.message,
|
||||
)
|
||||
|
||||
# Setup authorization parameters
|
||||
auth_params = AuthorizationParams(
|
||||
state=state,
|
||||
scopes=scopes,
|
||||
code_challenge=auth_request.code_challenge,
|
||||
redirect_uri=redirect_uri,
|
||||
redirect_uri_provided_explicitly=auth_request.redirect_uri is not None,
|
||||
resource=auth_request.resource, # RFC 8707
|
||||
)
|
||||
|
||||
try:
|
||||
# Let the provider pick the next URI to redirect to
|
||||
return RedirectResponse(
|
||||
url=await self.provider.authorize(
|
||||
client,
|
||||
auth_params,
|
||||
),
|
||||
status_code=302,
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
except AuthorizeError as e:
|
||||
# Handle authorization errors as defined in RFC 6749 Section 4.1.2.1
|
||||
return await error_response(error=e.error, error_description=e.error_description)
|
||||
|
||||
except Exception as validation_error: # pragma: no cover
|
||||
# Catch-all for unexpected errors
|
||||
logger.exception("Unexpected error in authorization_handler", exc_info=validation_error)
|
||||
return await error_response(error="server_error", error_description="An unexpected error occurred")
|
||||
@@ -0,0 +1,29 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataHandler:
|
||||
metadata: OAuthMetadata
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
return PydanticJSONResponse(
|
||||
content=self.metadata,
|
||||
headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProtectedResourceMetadataHandler:
|
||||
metadata: ProtectedResourceMetadata
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
return PydanticJSONResponse(
|
||||
content=self.metadata,
|
||||
headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour
|
||||
)
|
||||
@@ -0,0 +1,136 @@
|
||||
import secrets
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, RootModel, ValidationError
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, RegistrationError, RegistrationErrorCode
|
||||
from mcp.server.auth.settings import ClientRegistrationOptions
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
|
||||
|
||||
|
||||
class RegistrationRequest(RootModel[OAuthClientMetadata]):
|
||||
# this wrapper is a no-op; it's just to separate out the types exposed to the
|
||||
# provider from what we use in the HTTP handler
|
||||
root: OAuthClientMetadata
|
||||
|
||||
|
||||
class RegistrationErrorResponse(BaseModel):
|
||||
error: RegistrationErrorCode
|
||||
error_description: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegistrationHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
options: ClientRegistrationOptions
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
# Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1
|
||||
try:
|
||||
# Parse request body as JSON
|
||||
body = await request.json()
|
||||
client_metadata = OAuthClientMetadata.model_validate(body)
|
||||
|
||||
# Scope validation is handled below
|
||||
except ValidationError as validation_error:
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description=stringify_pydantic_error(validation_error),
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
client_id = str(uuid4())
|
||||
|
||||
# If auth method is None, default to client_secret_post
|
||||
if client_metadata.token_endpoint_auth_method is None:
|
||||
client_metadata.token_endpoint_auth_method = "client_secret_post"
|
||||
|
||||
client_secret = None
|
||||
if client_metadata.token_endpoint_auth_method != "none": # pragma: no branch
|
||||
# cryptographically secure random 32-byte hex string
|
||||
client_secret = secrets.token_hex(32)
|
||||
|
||||
if client_metadata.scope is None and self.options.default_scopes is not None:
|
||||
client_metadata.scope = " ".join(self.options.default_scopes)
|
||||
elif client_metadata.scope is not None and self.options.valid_scopes is not None:
|
||||
requested_scopes = set(client_metadata.scope.split())
|
||||
valid_scopes = set(self.options.valid_scopes)
|
||||
if not requested_scopes.issubset(valid_scopes): # pragma: no branch
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description="Requested scopes are not valid: "
|
||||
f"{', '.join(requested_scopes - valid_scopes)}",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
if not {"authorization_code", "refresh_token"}.issubset(set(client_metadata.grant_types)):
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description="grant_types must be authorization_code and refresh_token",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# The MCP spec requires servers to use the authorization `code` flow
|
||||
# with PKCE
|
||||
if "code" not in client_metadata.response_types:
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description="response_types must include 'code' for authorization_code grant",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
client_id_issued_at = int(time.time())
|
||||
client_secret_expires_at = (
|
||||
client_id_issued_at + self.options.client_secret_expiry_seconds
|
||||
if self.options.client_secret_expiry_seconds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
client_info = OAuthClientInformationFull(
|
||||
client_id=client_id,
|
||||
client_id_issued_at=client_id_issued_at,
|
||||
client_secret=client_secret,
|
||||
client_secret_expires_at=client_secret_expires_at,
|
||||
# passthrough information from the client request
|
||||
redirect_uris=client_metadata.redirect_uris,
|
||||
token_endpoint_auth_method=client_metadata.token_endpoint_auth_method,
|
||||
grant_types=client_metadata.grant_types,
|
||||
response_types=client_metadata.response_types,
|
||||
client_name=client_metadata.client_name,
|
||||
client_uri=client_metadata.client_uri,
|
||||
logo_uri=client_metadata.logo_uri,
|
||||
scope=client_metadata.scope,
|
||||
contacts=client_metadata.contacts,
|
||||
tos_uri=client_metadata.tos_uri,
|
||||
policy_uri=client_metadata.policy_uri,
|
||||
jwks_uri=client_metadata.jwks_uri,
|
||||
jwks=client_metadata.jwks,
|
||||
software_id=client_metadata.software_id,
|
||||
software_version=client_metadata.software_version,
|
||||
)
|
||||
try:
|
||||
# Register client
|
||||
await self.provider.register_client(client_info)
|
||||
|
||||
# Return client information
|
||||
return PydanticJSONResponse(content=client_info, status_code=201)
|
||||
except RegistrationError as e:
|
||||
# Handle registration errors as defined in RFC 7591 Section 3.2.2
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(error=e.error, error_description=e.error_description),
|
||||
status_code=400,
|
||||
)
|
||||
@@ -0,0 +1,91 @@
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from mcp.server.auth.errors import (
|
||||
stringify_pydantic_error,
|
||||
)
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
|
||||
from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, RefreshToken
|
||||
|
||||
|
||||
class RevocationRequest(BaseModel):
|
||||
"""
|
||||
# See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
|
||||
"""
|
||||
|
||||
token: str
|
||||
token_type_hint: Literal["access_token", "refresh_token"] | None = None
|
||||
client_id: str
|
||||
client_secret: str | None
|
||||
|
||||
|
||||
class RevocationErrorResponse(BaseModel):
|
||||
error: Literal["invalid_request", "unauthorized_client"]
|
||||
error_description: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RevocationHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
client_authenticator: ClientAuthenticator
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
"""
|
||||
Handler for the OAuth 2.0 Token Revocation endpoint.
|
||||
"""
|
||||
try:
|
||||
client = await self.client_authenticator.authenticate_request(request)
|
||||
except AuthenticationError as e: # pragma: no cover
|
||||
return PydanticJSONResponse(
|
||||
status_code=401,
|
||||
content=RevocationErrorResponse(
|
||||
error="unauthorized_client",
|
||||
error_description=e.message,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
form_data = await request.form()
|
||||
revocation_request = RevocationRequest.model_validate(dict(form_data))
|
||||
except ValidationError as e:
|
||||
return PydanticJSONResponse(
|
||||
status_code=400,
|
||||
content=RevocationErrorResponse(
|
||||
error="invalid_request",
|
||||
error_description=stringify_pydantic_error(e),
|
||||
),
|
||||
)
|
||||
|
||||
loaders = [
|
||||
self.provider.load_access_token,
|
||||
partial(self.provider.load_refresh_token, client),
|
||||
]
|
||||
if revocation_request.token_type_hint == "refresh_token": # pragma: no cover
|
||||
loaders = reversed(loaders)
|
||||
|
||||
token: None | AccessToken | RefreshToken = None
|
||||
for loader in loaders:
|
||||
token = await loader(revocation_request.token)
|
||||
if token is not None:
|
||||
break
|
||||
|
||||
# if token is not found, just return HTTP 200 per the RFC
|
||||
if token and token.client_id == client.client_id:
|
||||
# Revoke token; provider is not meant to be able to do validation
|
||||
# at this point that would result in an error
|
||||
await self.provider.revoke_token(token)
|
||||
|
||||
# Return successful empty response
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Cache-Control": "no-store",
|
||||
"Pragma": "no-cache",
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,241 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
|
||||
from starlette.requests import Request
|
||||
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode
|
||||
from mcp.shared.auth import OAuthToken
|
||||
|
||||
|
||||
class AuthorizationCodeRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
|
||||
grant_type: Literal["authorization_code"]
|
||||
code: str = Field(..., description="The authorization code")
|
||||
redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize")
|
||||
client_id: str
|
||||
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
|
||||
client_secret: str | None = None
|
||||
# See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
|
||||
code_verifier: str = Field(..., description="PKCE code verifier")
|
||||
# RFC 8707 resource indicator
|
||||
resource: str | None = Field(None, description="Resource indicator for the token")
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-6
|
||||
grant_type: Literal["refresh_token"]
|
||||
refresh_token: str = Field(..., description="The refresh token")
|
||||
scope: str | None = Field(None, description="Optional scope parameter")
|
||||
client_id: str
|
||||
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
|
||||
client_secret: str | None = None
|
||||
# RFC 8707 resource indicator
|
||||
resource: str | None = Field(None, description="Resource indicator for the token")
|
||||
|
||||
|
||||
class TokenRequest(
|
||||
RootModel[
|
||||
Annotated[
|
||||
AuthorizationCodeRequest | RefreshTokenRequest,
|
||||
Field(discriminator="grant_type"),
|
||||
]
|
||||
]
|
||||
):
|
||||
root: Annotated[
|
||||
AuthorizationCodeRequest | RefreshTokenRequest,
|
||||
Field(discriminator="grant_type"),
|
||||
]
|
||||
|
||||
|
||||
class TokenErrorResponse(BaseModel):
|
||||
"""
|
||||
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
||||
"""
|
||||
|
||||
error: TokenErrorCode
|
||||
error_description: str | None = None
|
||||
error_uri: AnyHttpUrl | None = None
|
||||
|
||||
|
||||
class TokenSuccessResponse(RootModel[OAuthToken]):
|
||||
# this is just a wrapper over OAuthToken; the only reason we do this
|
||||
# is to have some separation between the HTTP response type, and the
|
||||
# type returned by the provider
|
||||
root: OAuthToken
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
client_authenticator: ClientAuthenticator
|
||||
|
||||
def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
|
||||
status_code = 200
|
||||
if isinstance(obj, TokenErrorResponse):
|
||||
status_code = 400
|
||||
|
||||
return PydanticJSONResponse(
|
||||
content=obj,
|
||||
status_code=status_code,
|
||||
headers={
|
||||
"Cache-Control": "no-store",
|
||||
"Pragma": "no-cache",
|
||||
},
|
||||
)
|
||||
|
||||
async def handle(self, request: Request):
|
||||
try:
|
||||
client_info = await self.client_authenticator.authenticate_request(request)
|
||||
except AuthenticationError as e:
|
||||
# Authentication failures should return 401
|
||||
return PydanticJSONResponse(
|
||||
content=TokenErrorResponse(
|
||||
error="unauthorized_client",
|
||||
error_description=e.message,
|
||||
),
|
||||
status_code=401,
|
||||
headers={
|
||||
"Cache-Control": "no-store",
|
||||
"Pragma": "no-cache",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
form_data = await request.form()
|
||||
token_request = TokenRequest.model_validate(dict(form_data)).root
|
||||
except ValidationError as validation_error: # pragma: no cover
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_request",
|
||||
error_description=stringify_pydantic_error(validation_error),
|
||||
)
|
||||
)
|
||||
|
||||
if token_request.grant_type not in client_info.grant_types: # pragma: no cover
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="unsupported_grant_type",
|
||||
error_description=(f"Unsupported grant type (supported grant types are {client_info.grant_types})"),
|
||||
)
|
||||
)
|
||||
|
||||
tokens: OAuthToken
|
||||
|
||||
match token_request:
|
||||
case AuthorizationCodeRequest():
|
||||
auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
|
||||
if auth_code is None or auth_code.client_id != token_request.client_id:
|
||||
# if code belongs to different client, pretend it doesn't exist
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="authorization code does not exist",
|
||||
)
|
||||
)
|
||||
|
||||
# make auth codes expire after a deadline
|
||||
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
|
||||
if auth_code.expires_at < time.time():
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="authorization code has expired",
|
||||
)
|
||||
)
|
||||
|
||||
# verify redirect_uri doesn't change between /authorize and /tokens
|
||||
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
|
||||
if auth_code.redirect_uri_provided_explicitly:
|
||||
authorize_request_redirect_uri = auth_code.redirect_uri
|
||||
else: # pragma: no cover
|
||||
authorize_request_redirect_uri = None
|
||||
|
||||
# Convert both sides to strings for comparison to handle AnyUrl vs string issues
|
||||
token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None
|
||||
auth_redirect_str = (
|
||||
str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None
|
||||
)
|
||||
|
||||
if token_redirect_str != auth_redirect_str:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_request",
|
||||
error_description=("redirect_uri did not match the one used when creating auth code"),
|
||||
)
|
||||
)
|
||||
|
||||
# Verify PKCE code verifier
|
||||
sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
|
||||
hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
|
||||
|
||||
if hashed_code_verifier != auth_code.code_challenge:
|
||||
# see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="incorrect code_verifier",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Exchange authorization code for tokens
|
||||
tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
|
||||
except TokenError as e:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error=e.error,
|
||||
error_description=e.error_description,
|
||||
)
|
||||
)
|
||||
|
||||
case RefreshTokenRequest(): # pragma: no cover
|
||||
refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
|
||||
if refresh_token is None or refresh_token.client_id != token_request.client_id:
|
||||
# if token belongs to different client, pretend it doesn't exist
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="refresh token does not exist",
|
||||
)
|
||||
)
|
||||
|
||||
if refresh_token.expires_at and refresh_token.expires_at < time.time():
|
||||
# if the refresh token has expired, pretend it doesn't exist
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="refresh token has expired",
|
||||
)
|
||||
)
|
||||
|
||||
# Parse scopes if provided
|
||||
scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes
|
||||
|
||||
for scope in scopes:
|
||||
if scope not in refresh_token.scopes:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_scope",
|
||||
error_description=(f"cannot request scope `{scope}` not provided by refresh token"),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Exchange refresh token for new tokens
|
||||
tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
|
||||
except TokenError as e:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error=e.error,
|
||||
error_description=e.error_description,
|
||||
)
|
||||
)
|
||||
|
||||
return self.response(TokenSuccessResponse(root=tokens))
|
||||
Reference in New Issue
Block a user