Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/a2a/compat/v0_3/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from a2a.server.context import ServerCallContext
from a2a.server.request_handlers.grpc_handler import (
_ERROR_CODE_MAP,
CallContextBuilder,
DefaultCallContextBuilder,
DefaultGrpcContextBuilder,
GrpcContextBuilder,
)
from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.types.a2a_pb2 import AgentCard
Expand All @@ -44,7 +44,7 @@ def __init__(
self,
agent_card: AgentCard,
request_handler: RequestHandler,
context_builder: CallContextBuilder | None = None,
context_builder: GrpcContextBuilder | None = None,
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
| None = None,
):
Expand All @@ -54,14 +54,14 @@ def __init__(
agent_card: The AgentCard describing the agent's capabilities (v1.0).
request_handler: The underlying `RequestHandler` instance to
delegate requests to.
context_builder: The CallContextBuilder object. If none the
DefaultCallContextBuilder is used.
context_builder: Optional custom user builder to extract user from the
gRPC context.
card_modifier: An optional callback to dynamically modify the public
agent card before it is served.
"""
self.agent_card = agent_card
self.handler03 = RequestHandler03(request_handler=request_handler)
self.context_builder = context_builder or DefaultCallContextBuilder()
self._context_builder = context_builder or DefaultGrpcContextBuilder()
self.card_modifier = card_modifier

async def _handle_unary(
Expand All @@ -72,7 +72,7 @@ async def _handle_unary(
) -> TResponse:
"""Centralized error handling and context management for unary calls."""
try:
server_context = self.context_builder.build(context)
server_context = self._context_builder.build(context)
result = await handler_func(server_context)
self._set_extension_metadata(context, server_context)
except A2AError as e:
Expand All @@ -88,7 +88,7 @@ async def _handle_stream(
) -> AsyncIterable[TResponse]:
"""Centralized error handling and context management for streaming calls."""
try:
server_context = self.context_builder.build(context)
server_context = self._context_builder.build(context)
async for item in handler_func(server_context):
yield item
self._set_extension_metadata(context, server_context)
Expand Down
15 changes: 7 additions & 8 deletions src/a2a/compat/v0_3/jsonrpc_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from starlette.requests import Request

from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.server.routes import CallContextBuilder
from a2a.types.a2a_pb2 import AgentCard

_package_starlette_installed = True
Expand All @@ -38,6 +37,10 @@
from a2a.server.jsonrpc_models import (
JSONRPCError as CoreJSONRPCError,
)
from a2a.server.routes.common import (
ContextBuilder,
DefaultContextBuilder,
)
from a2a.utils import constants
from a2a.utils.errors import ExtendedAgentCardNotConfiguredError
from a2a.utils.helpers import maybe_await, validate_version
Expand Down Expand Up @@ -67,7 +70,7 @@ def __init__( # noqa: PLR0913
agent_card: 'AgentCard',
http_handler: 'RequestHandler',
extended_agent_card: 'AgentCard | None' = None,
context_builder: 'CallContextBuilder | None' = None,
context_builder: 'ContextBuilder | None' = None,
card_modifier: 'Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None' = None,
extended_card_modifier: 'Callable[[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard] | None' = None,
):
Expand All @@ -78,7 +81,7 @@ def __init__( # noqa: PLR0913
self.handler = RequestHandler03(
request_handler=http_handler,
)
self._context_builder = context_builder
self._context_builder = context_builder or DefaultContextBuilder()

def supports_method(self, method: str) -> bool:
"""Returns True if the v0.3 adapter supports the given method name."""
Expand Down Expand Up @@ -126,11 +129,7 @@ async def handle_request(
CoreInvalidRequestError(data=str(e)),
)

call_context = (
self._context_builder.build(request)
if self._context_builder
else ServerCallContext()
)
call_context = self._context_builder.build(request)
call_context.tenant = (
getattr(specific_request.params, 'tenant', '')
if hasattr(specific_request, 'params')
Expand Down
9 changes: 6 additions & 3 deletions src/a2a/compat/v0_3/rest_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@
from a2a.compat.v0_3 import conversions
from a2a.compat.v0_3.rest_handler import REST03Handler
from a2a.server.context import ServerCallContext
from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder
from a2a.server.routes.common import (
ContextBuilder,
DefaultContextBuilder,
)
from a2a.utils.error_handlers import (
rest_error_handler,
rest_stream_error_handler,
Expand All @@ -60,7 +63,7 @@ def __init__( # noqa: PLR0913
agent_card: 'AgentCard',
http_handler: 'RequestHandler',
extended_agent_card: 'AgentCard | None' = None,
context_builder: 'CallContextBuilder | None' = None,
context_builder: 'ContextBuilder | None' = None,
card_modifier: 'Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] | None' = None,
extended_card_modifier: 'Callable[[AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard] | None' = None,
):
Expand All @@ -71,7 +74,7 @@ def __init__( # noqa: PLR0913
self.handler = REST03Handler(
agent_card=agent_card, request_handler=http_handler
)
self._context_builder = context_builder or DefaultCallContextBuilder()
self._context_builder = context_builder or DefaultContextBuilder()

@rest_error_handler
async def _handle_request(
Expand Down
56 changes: 29 additions & 27 deletions src/a2a/server/request_handlers/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import a2a.types.a2a_pb2_grpc as a2a_grpc

from a2a import types
from a2a.auth.user import UnauthenticatedUser
from a2a.auth.user import UnauthenticatedUser, User
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
get_requested_extensions,
Expand All @@ -41,15 +41,32 @@

logger = logging.getLogger(__name__)

# For now we use a trivial wrapper on the grpc context object


class CallContextBuilder(ABC):
"""A class for building ServerCallContexts using the Starlette Request."""
class GrpcContextBuilder(ABC):
Comment thread
guglielmo-san marked this conversation as resolved.
Outdated
"""Interface for building ServerCallContext from gRPC context."""

@abstractmethod
def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
"""Builds a ServerCallContext from a gRPC Request."""
"""Builds a ServerCallContext from a gRPC ServicerContext."""


class DefaultGrpcContextBuilder(GrpcContextBuilder):
"""Default implementation of GrpcContextBuilder."""

def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
"""Builds a ServerCallContext from a gRPC ServicerContext."""
state = {'grpc_context': context}
return ServerCallContext(
user=self.build_user(context),
state=state,
requested_extensions=get_requested_extensions(
_get_metadata_value(context, HTTP_EXTENSION_HEADER)
),
)

def build_user(self, context: grpc.aio.ServicerContext) -> User:
"""Builds a User from a gRPC ServicerContext."""
return UnauthenticatedUser()


def _get_metadata_value(
Expand All @@ -67,22 +84,6 @@ def _get_metadata_value(
]


class DefaultCallContextBuilder(CallContextBuilder):
"""A default implementation of CallContextBuilder."""

def build(self, context: grpc.aio.ServicerContext) -> ServerCallContext:
"""Builds the ServerCallContext."""
user = UnauthenticatedUser()
state = {'grpc_context': context}
return ServerCallContext(
user=user,
state=state,
requested_extensions=get_requested_extensions(
_get_metadata_value(context, HTTP_EXTENSION_HEADER)
),
)


_ERROR_CODE_MAP = {
types.InvalidRequestError: grpc.StatusCode.INVALID_ARGUMENT,
types.MethodNotFoundError: grpc.StatusCode.NOT_FOUND,
Expand Down Expand Up @@ -110,7 +111,7 @@ def __init__(
self,
agent_card: AgentCard,
request_handler: RequestHandler,
context_builder: CallContextBuilder | None = None,
context_builder: GrpcContextBuilder | None = None,
card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard]
| None = None,
):
Expand All @@ -120,14 +121,15 @@ def __init__(
agent_card: The AgentCard describing the agent's capabilities.
request_handler: The underlying `RequestHandler` instance to
delegate requests to.
context_builder: The CallContextBuilder object. If none the
DefaultCallContextBuilder is used.
context_builder: The GrpcContextBuilder used to construct the
ServerCallContext passed to the request_handler. If None the
DefaultGrpcContextBuilder is used.
card_modifier: An optional callback to dynamically modify the public
agent card before it is served.
"""
self.agent_card = agent_card
self.request_handler = request_handler
self.context_builder = context_builder or DefaultCallContextBuilder()
self._context_builder = context_builder or DefaultGrpcContextBuilder()
self.card_modifier = card_modifier

async def _handle_unary(
Expand Down Expand Up @@ -451,6 +453,6 @@ def _build_call_context(
context: grpc.aio.ServicerContext,
request: message.Message,
) -> ServerCallContext:
server_context = self.context_builder.build(context)
server_context = self._context_builder.build(context)
server_context.tenant = getattr(request, 'tenant', '')
return server_context
9 changes: 3 additions & 6 deletions src/a2a/server/routes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
"""A2A Routes."""

from a2a.server.routes.agent_card_routes import create_agent_card_routes
from a2a.server.routes.jsonrpc_dispatcher import (
CallContextBuilder,
DefaultCallContextBuilder,
)
from a2a.server.routes.common import ContextBuilder, DefaultContextBuilder
from a2a.server.routes.jsonrpc_routes import create_jsonrpc_routes
from a2a.server.routes.rest_routes import create_rest_routes


__all__ = [
'CallContextBuilder',
'DefaultCallContextBuilder',
'ContextBuilder',
'DefaultContextBuilder',
'create_agent_card_routes',
'create_jsonrpc_routes',
'create_rest_routes',
Expand Down
85 changes: 85 additions & 0 deletions src/a2a/server/routes/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any


if TYPE_CHECKING:
from starlette.authentication import BaseUser
from starlette.requests import Request
else:
try:
from starlette.authentication import BaseUser
from starlette.requests import Request
except ImportError:
Request = Any
BaseUser = Any

from a2a.auth.user import UnauthenticatedUser, User
from a2a.extensions.common import (
HTTP_EXTENSION_HEADER,
get_requested_extensions,
)
from a2a.server.context import ServerCallContext


class StarletteUser(User):
Comment thread
guglielmo-san marked this conversation as resolved.
"""Adapts a Starlette BaseUser to the A2A User interface."""

def __init__(self, user: BaseUser):
self._user = user

@property
def is_authenticated(self) -> bool:
"""Returns whether the current user is authenticated."""
return self._user.is_authenticated

@property
def user_name(self) -> str:
"""Returns the user name of the current user."""
return self._user.display_name
Comment thread
guglielmo-san marked this conversation as resolved.


class ContextBuilder(ABC):
Comment thread
guglielmo-san marked this conversation as resolved.
Outdated
"""A class for building ServerCallContexts using the Starlette Request."""

@abstractmethod
def build(self, request: Request) -> ServerCallContext:
"""Builds a ServerCallContext from a Starlette Request."""


class DefaultContextBuilder(ContextBuilder):
"""A default implementation of ContextBuilder."""

def build(self, request: Request) -> ServerCallContext:
"""Builds a ServerCallContext from a Starlette Request.

Args:
request: The incoming Starlette Request object.

Returns:
A ServerCallContext instance populated with user and state
information from the request.
"""
state = {}
if 'auth' in request.scope:
state['auth'] = request.auth
state['headers'] = dict(request.headers)
return ServerCallContext(
user=self.build_user(request),
state=state,
requested_extensions=get_requested_extensions(
request.headers.getlist(HTTP_EXTENSION_HEADER)
),
)

def build_user(self, request: Request) -> User:
"""Builds a User from a Starlette Request.

Args:
request: The incoming Starlette Request object.

Returns:
A User instance populated with user information from the request.
"""
if 'user' in request.scope:
return StarletteUser(request.user)
return UnauthenticatedUser()
Loading
Loading