Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Protocol
from typing import Any, NewType, Protocol
from urllib.parse import quote, urlencode, urljoin, urlparse

import anyio
Expand Down Expand Up @@ -53,6 +53,8 @@

logger = logging.getLogger(__name__)

AuthorizationState = NewType("AuthorizationState", str)


class PKCEParameters(BaseModel):
"""PKCE (Proof Key for Code Exchange) parameters."""
Expand Down Expand Up @@ -305,14 +307,10 @@ async def _perform_authorization(self) -> httpx.Request:
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
return token_request

async def _perform_authorization_code_grant(self) -> tuple[str, str]:
"""Perform the authorization redirect and get auth code."""
async def _build_authorization_url(self) -> tuple[str, AuthorizationState, PKCEParameters]:
"""Build authorization URL and state."""
if self.context.client_metadata.redirect_uris is None:
raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover
if not self.context.redirect_handler:
raise OAuthFlowError("No redirect handler provided for authorization code grant") # pragma: no cover
if not self.context.callback_handler:
raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover

if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint:
auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) # pragma: no cover
Expand All @@ -325,7 +323,7 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:

# Generate PKCE parameters
pkce_params = PKCEParameters.generate()
state = secrets.token_urlsafe(32)
state = AuthorizationState(secrets.token_urlsafe(32))

auth_params = {
"response_type": "code",
Expand All @@ -344,6 +342,17 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]:
auth_params["scope"] = self.context.client_metadata.scope

authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"

return authorization_url, state, pkce_params

async def _perform_authorization_code_grant(self) -> tuple[str, str]:
"""Perform the authorization redirect and get auth code."""
if not self.context.redirect_handler:
raise OAuthFlowError("No redirect handler provided for authorization code grant") # pragma: no cover
if not self.context.callback_handler:
raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover

authorization_url, state, pkce_params = await self._build_authorization_url()
await self.context.redirect_handler(authorization_url)

# Wait for callback
Expand Down