From c615757ba12093ba4a2ba19bee3f498fef91584c Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 24 Feb 2026 08:34:05 -0800 Subject: [PATCH 1/8] fix: Add support for injecting a custom google.genai.Client into Gemini models This change introduces a new `client` parameter to the `Gemini` model's constructor. When provided, this preconfigured `google.genai.Client` instance is used for all API calls, offering fine-grained control over authentication, project, and location settings Close #2560 Co-authored-by: George Weale PiperOrigin-RevId: 874628604 --- src/google/adk/models/google_llm.py | 56 ++++++++ tests/unittests/models/test_google_llm.py | 150 ++++++++++++++++++++++ 2 files changed, 206 insertions(+) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 23c9c27810..b8c5117e19 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -85,6 +85,23 @@ class Gemini(BaseLlm): Attributes: model: The name of the Gemini model. + client: An optional preconfigured ``google.genai.Client`` instance. + When provided, ADK uses this client for all API calls instead of + creating one internally from environment variables or ADC. This + allows fine-grained control over authentication, project, location, + and other client-level settings — and enables running agents that + target different Vertex AI regions within the same process. + + Example:: + + from google import genai + from google.adk.models import Gemini + + client = genai.Client( + vertexai=True, project="my-project", location="us-central1" + ) + model = Gemini(model="gemini-2.5-flash", client=client) + use_interactions_api: Whether to use the interactions API for model invocation. """ @@ -131,6 +148,35 @@ class Gemini(BaseLlm): ``` """ + def __init__(self, *, client: Optional[Client] = None, **kwargs: Any): + """Initialises a Gemini model wrapper. + + Args: + client: An optional preconfigured ``google.genai.Client``. 
When + provided, ADK uses this client for **all** Gemini API calls + (including the Live API) instead of creating one internally. + + .. note:: + When a custom client is supplied it is used as-is for Live API + connections. ADK will **not** override the client's + ``api_version``; you are responsible for setting the correct + version (``v1beta1`` for Vertex AI, ``v1alpha`` for the + Gemini developer API) on the client yourself. + + .. warning:: + ``google.genai.Client`` contains threading primitives that + cannot be pickled. If you are deploying to Agent Engine (or + any environment that serialises the model), do **not** pass a + custom client — let ADK create one from the environment + instead. + + **kwargs: Forwarded to the Pydantic ``BaseLlm`` constructor + (``model``, ``base_url``, ``retry_options``, etc.). + """ + super().__init__(**kwargs) + # Store after super().__init__ so Pydantic validation runs first. + object.__setattr__(self, '_client', client) + @classmethod @override def supported_models(cls) -> list[str]: @@ -299,9 +345,16 @@ async def _generate_content_via_interactions( def api_client(self) -> Client: """Provides the api client. + If a preconfigured ``client`` was passed to the constructor it is + returned directly; otherwise a new ``Client`` is created using the + default environment/ADC configuration. + Returns: The api client. 
""" + if self._client is not None: + return self._client + from google.genai import Client return Client( @@ -334,6 +387,9 @@ def _live_api_version(self) -> str: @cached_property def _live_api_client(self) -> Client: + if self._client is not None: + return self._client + from google.genai import Client return Client( diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 70aa01b69d..75d4c0fd48 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -2140,3 +2140,153 @@ async def __aexit__(self, *args): # Verify the final speech_config is still None assert config_arg.speech_config is None assert isinstance(connection, GeminiLlmConnection) + + +# --------------------------------------------------------------------------- +# Tests for custom client injection (Issue #2560) +# --------------------------------------------------------------------------- + + +def test_custom_client_is_used_for_api_client(): + """When a custom client is provided, api_client returns it directly.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + assert gemini.api_client is custom_client + + +def test_custom_client_is_used_for_live_api_client(): + """When a custom client is provided, _live_api_client returns it directly.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + assert gemini._live_api_client is custom_client + + +def test_default_api_client_when_no_custom_client(): + """Without a custom client, api_client creates a default Client.""" + gemini = Gemini(model="gemini-1.5-flash") + + # api_client should construct a real Client (not None) + client = gemini.api_client + assert client is not None + # Verify it is not a mock — it's a real google.genai.Client + from google.genai import 
Client + + assert isinstance(client, Client) + + +def test_default_live_api_client_when_no_custom_client(): + """Without a custom client, _live_api_client creates a default Client.""" + gemini = Gemini(model="gemini-1.5-flash") + + client = gemini._live_api_client + assert client is not None + from google.genai import Client + + assert isinstance(client, Client) + + +def test_custom_client_api_backend_vertexai(): + """_api_backend reflects the custom client's vertexai setting.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + custom_client.vertexai = True + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + assert gemini._api_backend == GoogleLLMVariant.VERTEX_AI + + +def test_custom_client_api_backend_gemini_api(): + """_api_backend reflects non-vertexai custom client.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + custom_client.vertexai = False + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + assert gemini._api_backend == GoogleLLMVariant.GEMINI_API + + +@pytest.mark.asyncio +async def test_custom_client_used_for_generate_content(): + """Custom client is used when generate_content_async is called.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + custom_client.vertexai = False + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + generate_content_response = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=Content( + role="model", + parts=[Part.from_text(text="Hello from custom client")], + ), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + + async def mock_coro(): + return generate_content_response + + custom_client.aio.models.generate_content.return_value = mock_coro() + + llm_request = LlmRequest( + model="gemini-1.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], + config=types.GenerateContentConfig( + system_instruction="You are 
a helpful assistant", + ), + ) + + responses = [ + resp + async for resp in gemini.generate_content_async(llm_request, stream=False) + ] + + assert len(responses) == 1 + assert responses[0].content.parts[0].text == "Hello from custom client" + custom_client.aio.models.generate_content.assert_called_once() + + +@pytest.mark.asyncio +async def test_custom_client_used_for_live_connect(): + """Custom client is used for live API streaming connections.""" + from google.genai import Client + + custom_client = mock.MagicMock(spec=Client) + custom_client.vertexai = False + gemini = Gemini(model="gemini-1.5-flash", client=custom_client) + + mock_live_session = mock.AsyncMock() + + class MockLiveConnect: + + async def __aenter__(self): + return mock_live_session + + async def __aexit__(self, *args): + pass + + custom_client.aio.live.connect.return_value = MockLiveConnect() + + llm_request = LlmRequest( + model="gemini-1.5-flash", + contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], + config=types.GenerateContentConfig( + system_instruction="You are a helpful assistant", + ), + ) + llm_request.live_connect_config = types.LiveConnectConfig() + + async with gemini.connect(llm_request) as connection: + custom_client.aio.live.connect.assert_called_once() + assert isinstance(connection, GeminiLlmConnection) From 7be90db24b41f1830e39ca3d7e15bf4dbfa5a304 Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 24 Feb 2026 08:38:34 -0800 Subject: [PATCH 2/8] feat: Support ID token exchange in ServiceAccountCredentialExchanger Adds use_id_token and audience fields to ServiceAccount so that ServiceAccountCredentialExchanger can produce ID tokens instead of access tokens. This is required for authenticating to Cloud Run, Cloud Functions, and other Google Cloud services that verify caller identity. 
Close #4458 Co-authored-by: George Weale PiperOrigin-RevId: 874630210 --- src/google/adk/auth/auth_credential.py | 39 ++- .../service_account_exchanger.py | 136 ++++++-- .../unittests/tools/mcp_tool/test_mcp_tool.py | 4 +- .../test_service_account_exchanger.py | 329 +++++++++++++----- 4 files changed, 402 insertions(+), 106 deletions(-) diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index e205d9be52..6160edcc02 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -25,6 +25,7 @@ from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field +from pydantic import model_validator class BaseModelWithConfig(BaseModel): @@ -145,11 +146,45 @@ class ServiceAccountCredential(BaseModelWithConfig): class ServiceAccount(BaseModelWithConfig): - """Represents Google Service Account configuration.""" + """Represents Google Service Account configuration. + + Attributes: + service_account_credential: The service account credential (JSON key). + scopes: The OAuth2 scopes to request. Optional; when omitted with + ``use_default_credential=True``, defaults to the cloud-platform scope. + use_default_credential: Whether to use Application Default Credentials. + use_id_token: Whether to exchange for an ID token instead of an access + token. Required for service-to-service authentication with Cloud Run, + Cloud Functions, and other Google Cloud services that require identity + verification. When True, ``audience`` must also be set. + audience: The target audience for the ID token, typically the URL of the + receiving service (e.g. ``https://my-service-xyz.run.app``). Required + when ``use_id_token`` is True. 
+ """ service_account_credential: Optional[ServiceAccountCredential] = None - scopes: List[str] + scopes: Optional[List[str]] = None use_default_credential: Optional[bool] = False + use_id_token: Optional[bool] = False + audience: Optional[str] = None + + @model_validator(mode="after") + def _validate_config(self) -> ServiceAccount: + if ( + not self.use_default_credential + and self.service_account_credential is None + ): + raise ValueError( + "service_account_credential is required when" + " use_default_credential is False." + ) + if self.use_id_token and not self.audience: + raise ValueError( + "audience is required when use_id_token is True. Set it to the" + " URL of the target service" + " (e.g. 'https://my-service.run.app')." + ) + return self class AuthCredentialTypes(str, Enum): diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py index 1dbe0fe46a..2b79edf997 100644 --- a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py @@ -19,6 +19,7 @@ from typing import Optional import google.auth +from google.auth import exceptions as google_auth_exceptions from google.auth.transport.requests import Request from google.oauth2 import service_account import google.oauth2.credentials @@ -27,6 +28,7 @@ from .....auth.auth_credential import AuthCredentialTypes from .....auth.auth_credential import HttpAuth from .....auth.auth_credential import HttpCredentials +from .....auth.auth_credential import ServiceAccount from .....auth.auth_schemes import AuthScheme from .base_credential_exchanger import AuthCredentialMissingError from .base_credential_exchanger import BaseAuthCredentialExchanger @@ -38,6 +40,11 @@ class ServiceAccountCredentialExchanger(BaseAuthCredentialExchanger): Uses the default service 
credential if `use_default_credential = True`. Otherwise, uses the service account credential provided in the auth credential. + + Supports exchanging for either an access token (default) or an ID token + when ``ServiceAccount.use_id_token`` is True. ID tokens are required for + service-to-service authentication with Cloud Run, Cloud Functions, and + other services that verify caller identity. """ def exchange_credential( @@ -45,52 +52,130 @@ def exchange_credential( auth_scheme: AuthScheme, auth_credential: Optional[AuthCredential] = None, ) -> AuthCredential: - """Exchanges the service account auth credential for an access token. + """Exchanges the service account auth credential for a token. If auth_credential contains a service account credential, it will be used - to fetch an access token. Otherwise, the default service credential will be - used for fetching an access token. + to fetch a token. Otherwise, the default service credential will be + used for fetching a token. + + When ``service_account.use_id_token`` is True, an ID token is fetched + using the configured ``audience``. This is required for authenticating + to Cloud Run, Cloud Functions, and similar services. Args: auth_scheme: The auth scheme. auth_credential: The auth credential. Returns: - An AuthCredential in HTTPBearer format, containing the access token. + An AuthCredential in HTTPBearer format, containing the token. """ - if ( - auth_credential is None - or auth_credential.service_account is None - or ( - auth_credential.service_account.service_account_credential is None - and not auth_credential.service_account.use_default_credential - ) - ): + if auth_credential is None or auth_credential.service_account is None: raise AuthCredentialMissingError( - "Service account credentials are missing. Please provide them, or set" - " `use_default_credential = True` to use application default" + "Service account credentials are missing. 
Please provide them, or" + " set `use_default_credential = True` to use application default" " credential in a hosted service like Cloud Run." ) + sa_config = auth_credential.service_account + + if sa_config.use_id_token: + return self._exchange_for_id_token(sa_config) + + return self._exchange_for_access_token(sa_config) + + def _exchange_for_id_token(self, sa_config: ServiceAccount) -> AuthCredential: + """Exchanges the service account credential for an ID token. + + Args: + sa_config: The service account configuration. + + Returns: + An AuthCredential in HTTPBearer format containing the ID token. + + Raises: + AuthCredentialMissingError: If token exchange fails. + """ + # audience and credential presence are validated by the ServiceAccount + # model_validator at construction time. try: - if auth_credential.service_account.use_default_credential: - credentials, project_id = google.auth.default( - scopes=["https://www.googleapis.com/auth/cloud-platform"], + if sa_config.use_default_credential: + from google.oauth2 import id_token as oauth2_id_token + + request = Request() + token = oauth2_id_token.fetch_id_token(request, sa_config.audience) + else: + # Guaranteed non-None by ServiceAccount model_validator. + assert sa_config.service_account_credential is not None + credentials = ( + service_account.IDTokenCredentials.from_service_account_info( + sa_config.service_account_credential.model_dump(), + target_audience=sa_config.audience, + ) ) - quota_project_id = ( - getattr(credentials, "quota_project_id", None) or project_id + credentials.refresh(Request()) + token = credentials.token + + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(token=token), + ), + ) + + # ValueError is raised by google-auth when service account JSON is + # missing required fields (e.g. client_email, private_key), or when + # fetch_id_token cannot determine credentials from the environment. 
+ except (google_auth_exceptions.GoogleAuthError, ValueError) as e: + raise AuthCredentialMissingError( + f"Failed to exchange service account for ID token: {e}" + ) from e + + def _exchange_for_access_token( + self, sa_config: ServiceAccount + ) -> AuthCredential: + """Exchanges the service account credential for an access token. + + Args: + sa_config: The service account configuration. + + Returns: + An AuthCredential in HTTPBearer format containing the access token. + + Raises: + AuthCredentialMissingError: If scopes are missing for explicit + credentials or token exchange fails. + """ + if not sa_config.use_default_credential and not sa_config.scopes: + raise AuthCredentialMissingError( + "scopes are required when using explicit service account credentials" + " for access token exchange." + ) + + try: + if sa_config.use_default_credential: + scopes = ( + sa_config.scopes + if sa_config.scopes + else ["https://www.googleapis.com/auth/cloud-platform"] + ) + credentials, project_id = google.auth.default( + scopes=scopes, ) + quota_project_id = credentials.quota_project_id or project_id else: - config = auth_credential.service_account + # Guaranteed non-None by ServiceAccount model_validator. 
+ assert sa_config.service_account_credential is not None credentials = service_account.Credentials.from_service_account_info( - config.service_account_credential.model_dump(), scopes=config.scopes + sa_config.service_account_credential.model_dump(), + scopes=sa_config.scopes, ) quota_project_id = None credentials.refresh(Request()) - updated_credential = AuthCredential( - auth_type=AuthCredentialTypes.HTTP, # Store as a bearer token + return AuthCredential( + auth_type=AuthCredentialTypes.HTTP, http=HttpAuth( scheme="bearer", credentials=HttpCredentials(token=credentials.token), @@ -101,9 +186,10 @@ def exchange_credential( else None, ), ) - return updated_credential - except Exception as e: + # ValueError is raised by google-auth when service account JSON is + # missing required fields (e.g. client_email, private_key). + except (google_auth_exceptions.GoogleAuthError, ValueError) as e: raise AuthCredentialMissingError( f"Failed to exchange service account token: {e}" ) from e diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index c4c85e7769..f38a8bbc7a 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -534,7 +534,9 @@ async def test_get_headers_service_account(self): ) # Create service account credential - service_account = ServiceAccount(scopes=["test"]) + service_account = ServiceAccount( + scopes=["test"], use_default_credential=True + ) credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=service_account, diff --git a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py index 0ca9944423..fb35daf64f 100644 --- a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py +++ 
b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py @@ -25,8 +25,23 @@ from google.adk.tools.openapi_tool.auth.credential_exchangers.base_credential_exchanger import AuthCredentialMissingError from google.adk.tools.openapi_tool.auth.credential_exchangers.service_account_exchanger import ServiceAccountCredentialExchanger import google.auth +from google.auth import exceptions as google_auth_exceptions import pytest +_ACCESS_TOKEN_MONKEYPATCH_TARGET = ( + "google.adk.tools.openapi_tool.auth.credential_exchangers." + "service_account_exchanger.service_account.Credentials." + "from_service_account_info" +) + +_ID_TOKEN_MONKEYPATCH_TARGET = ( + "google.adk.tools.openapi_tool.auth.credential_exchangers." + "service_account_exchanger.service_account.IDTokenCredentials." + "from_service_account_info" +) + +_FETCH_ID_TOKEN_MONKEYPATCH_TARGET = "google.oauth2.id_token.fetch_id_token" + @pytest.fixture def service_account_exchanger(): @@ -41,50 +56,45 @@ def auth_scheme(): return scheme -def test_exchange_credential_success( - service_account_exchanger, auth_scheme, monkeypatch +@pytest.fixture +def sa_credential(): + """A minimal valid ServiceAccountCredential for testing.""" + return ServiceAccountCredential( + type_="service_account", + project_id="test_project_id", + private_key_id="test_private_key_id", + private_key="-----BEGIN PRIVATE KEY-----...", + client_email="test@test.iam.gserviceaccount.com", + client_id="test_client_id", + auth_uri="https://accounts.google.com/o/oauth2/auth", + token_uri="https://oauth2.googleapis.com/token", + auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs", + client_x509_cert_url=( + "https://www.googleapis.com/robot/v1/metadata/x509/test" + ), + universe_domain="googleapis.com", + ) + + +_DEFAULT_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] + + +# --- Access token exchange tests --- + + +def test_exchange_access_token_with_explicit_credentials( + 
service_account_exchanger, auth_scheme, sa_credential, monkeypatch ): - """Test successful exchange of service account credentials.""" mock_credentials = MagicMock() mock_credentials.token = "mock_access_token" + mock_from_sa_info = MagicMock(return_value=mock_credentials) + monkeypatch.setattr(_ACCESS_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) - # Mock the from_service_account_info method - mock_from_service_account_info = MagicMock(return_value=mock_credentials) - target_path = ( - "google.adk.tools.openapi_tool.auth.credential_exchangers." - "service_account_exchanger.service_account.Credentials." - "from_service_account_info" - ) - monkeypatch.setattr( - target_path, - mock_from_service_account_info, - ) - - # Mock the refresh method - mock_credentials.refresh = MagicMock() - - # Create a valid AuthCredential with service account info auth_credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( - service_account_credential=ServiceAccountCredential( - type_="service_account", - project_id="your_project_id", - private_key_id="your_private_key_id", - private_key="-----BEGIN PRIVATE KEY-----...", - client_email="...@....iam.gserviceaccount.com", - client_id="your_client_id", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url=( - "https://www.googleapis.com/robot/v1/metadata/x509/..." 
- ), - universe_domain="googleapis.com", - ), - scopes=["https://www.googleapis.com/auth/cloud-platform"], + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, ), ) @@ -95,7 +105,7 @@ def test_exchange_credential_success( assert result.auth_type == AuthCredentialTypes.HTTP assert result.http.scheme == "bearer" assert result.http.credentials.token == "mock_access_token" - mock_from_service_account_info.assert_called_once() + mock_from_sa_info.assert_called_once() mock_credentials.refresh.assert_called_once() @@ -107,7 +117,7 @@ def test_exchange_credential_success( (None, None, None), ], ) -def test_exchange_credential_use_default_credential_success( +def test_exchange_access_token_with_adc_sets_quota_project( service_account_exchanger, auth_scheme, monkeypatch, @@ -115,7 +125,6 @@ def test_exchange_credential_use_default_credential_success( adc_project_id, expected_quota_project_id, ): - """Test successful exchange of service account credentials using default credential.""" mock_credentials = MagicMock() mock_credentials.token = "mock_access_token" mock_credentials.quota_project_id = cred_quota_project_id @@ -128,7 +137,7 @@ def test_exchange_credential_use_default_credential_success( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( use_default_credential=True, - scopes=["https://www.googleapis.com/auth/cloud-platform"], + scopes=["https://www.googleapis.com/auth/bigquery"], ), ) @@ -146,26 +155,49 @@ def test_exchange_credential_use_default_credential_success( ) else: assert not result.http.additional_headers - # Verify google.auth.default is called with the correct scopes parameter mock_google_auth_default.assert_called_once_with( - scopes=["https://www.googleapis.com/auth/cloud-platform"] + scopes=["https://www.googleapis.com/auth/bigquery"] ) mock_credentials.refresh.assert_called_once() -def test_exchange_credential_missing_auth_credential( +def test_exchange_access_token_with_adc_defaults_to_cloud_platform_scope( 
+ service_account_exchanger, auth_scheme, monkeypatch +): + mock_credentials = MagicMock() + mock_credentials.token = "mock_access_token" + mock_credentials.quota_project_id = None + mock_google_auth_default = MagicMock(return_value=(mock_credentials, None)) + monkeypatch.setattr(google.auth, "default", mock_google_auth_default) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + ), + ) + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential + ) + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_access_token" + mock_google_auth_default.assert_called_once_with(scopes=_DEFAULT_SCOPES) + + +def test_exchange_raises_when_auth_credential_is_none( service_account_exchanger, auth_scheme ): - """Test missing auth credential during exchange.""" with pytest.raises(AuthCredentialMissingError) as exc_info: service_account_exchanger.exchange_credential(auth_scheme, None) assert "Service account credentials are missing" in str(exc_info.value) -def test_exchange_credential_missing_service_account_info( +def test_exchange_raises_when_service_account_is_none( service_account_exchanger, auth_scheme ): - """Test missing service account info during exchange.""" auth_credential = AuthCredential( auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, ) @@ -174,47 +206,188 @@ def test_exchange_credential_missing_service_account_info( assert "Service account credentials are missing" in str(exc_info.value) -def test_exchange_credential_exchange_failure( +def test_exchange_wraps_google_auth_error_as_missing_error( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch +): + mock_from_sa_info = MagicMock( + side_effect=ValueError("Failed to load credentials") + ) + monkeypatch.setattr(_ACCESS_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) + + auth_credential = 
AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, + ), + ) + + with pytest.raises(AuthCredentialMissingError) as exc_info: + service_account_exchanger.exchange_credential(auth_scheme, auth_credential) + assert "Failed to exchange service account token" in str(exc_info.value) + mock_from_sa_info.assert_called_once() + + +def test_exchange_raises_when_explicit_credentials_have_no_scopes( + service_account_exchanger, auth_scheme, sa_credential +): + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=sa_credential, + ), + ) + + with pytest.raises(AuthCredentialMissingError) as exc_info: + service_account_exchanger.exchange_credential(auth_scheme, auth_credential) + assert "scopes are required" in str(exc_info.value) + + +# --- ID token exchange tests --- + + +def test_exchange_id_token_with_explicit_credentials( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch +): + mock_id_credentials = MagicMock() + mock_id_credentials.token = "mock_id_token" + mock_from_sa_info = MagicMock(return_value=mock_id_credentials) + monkeypatch.setattr(_ID_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential + ) + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_id_token" + assert result.http.additional_headers is None + mock_from_sa_info.assert_called_once() + assert ( + mock_from_sa_info.call_args[1]["target_audience"] + == 
"https://my-service.run.app" + ) + mock_id_credentials.refresh.assert_called_once() + + +def test_exchange_id_token_with_adc( service_account_exchanger, auth_scheme, monkeypatch ): - """Test failure during service account token exchange.""" - mock_from_service_account_info = MagicMock( - side_effect=Exception("Failed to load credentials") + mock_fetch_id_token = MagicMock(return_value="mock_adc_id_token") + monkeypatch.setattr(_FETCH_ID_TOKEN_MONKEYPATCH_TARGET, mock_fetch_id_token) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), ) - target_path = ( - "google.adk.tools.openapi_tool.auth.credential_exchangers." - "service_account_exchanger.service_account.Credentials." - "from_service_account_info" + + result = service_account_exchanger.exchange_credential( + auth_scheme, auth_credential ) - monkeypatch.setattr( - target_path, - mock_from_service_account_info, + + assert result.auth_type == AuthCredentialTypes.HTTP + assert result.http.scheme == "bearer" + assert result.http.credentials.token == "mock_adc_id_token" + assert result.http.additional_headers is None + mock_fetch_id_token.assert_called_once() + assert mock_fetch_id_token.call_args[0][1] == "https://my-service.run.app" + + +def test_id_token_requires_audience(): + with pytest.raises( + ValueError, match="audience is required when use_id_token is True" + ): + ServiceAccount( + use_default_credential=True, + use_id_token=True, + ) + + +def test_exchange_id_token_wraps_error_with_explicit_credentials( + service_account_exchanger, auth_scheme, sa_credential, monkeypatch +): + mock_from_sa_info = MagicMock( + side_effect=ValueError("Failed to create ID token credentials") ) + monkeypatch.setattr(_ID_TOKEN_MONKEYPATCH_TARGET, mock_from_sa_info) auth_credential = AuthCredential( 
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, service_account=ServiceAccount( - service_account_credential=ServiceAccountCredential( - type_="service_account", - project_id="your_project_id", - private_key_id="your_private_key_id", - private_key="-----BEGIN PRIVATE KEY-----...", - client_email="...@....iam.gserviceaccount.com", - client_id="your_client_id", - auth_uri="https://accounts.google.com/o/oauth2/auth", - token_uri="https://oauth2.googleapis.com/token", - auth_provider_x509_cert_url=( - "https://www.googleapis.com/oauth2/v1/certs" - ), - client_x509_cert_url=( - "https://www.googleapis.com/robot/v1/metadata/x509/..." - ), - universe_domain="googleapis.com", - ), - scopes=["https://www.googleapis.com/auth/cloud-platform"], + service_account_credential=sa_credential, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", ), ) + with pytest.raises(AuthCredentialMissingError) as exc_info: service_account_exchanger.exchange_credential(auth_scheme, auth_credential) - assert "Failed to exchange service account token" in str(exc_info.value) - mock_from_service_account_info.assert_called_once() + assert "Failed to exchange service account for ID token" in str( + exc_info.value + ) + + +def test_exchange_id_token_wraps_error_with_adc( + service_account_exchanger, auth_scheme, monkeypatch +): + mock_fetch_id_token = MagicMock( + side_effect=google_auth_exceptions.DefaultCredentialsError( + "Metadata service unavailable" + ) + ) + monkeypatch.setattr(_FETCH_ID_TOKEN_MONKEYPATCH_TARGET, mock_fetch_id_token) + + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.SERVICE_ACCOUNT, + service_account=ServiceAccount( + use_default_credential=True, + scopes=_DEFAULT_SCOPES, + use_id_token=True, + audience="https://my-service.run.app", + ), + ) + + with pytest.raises(AuthCredentialMissingError) as exc_info: + service_account_exchanger.exchange_credential(auth_scheme, auth_credential) + assert "Failed to exchange service account 
for ID token" in str( + exc_info.value + ) + + +# --- Model validator tests --- + + +def test_model_validator_rejects_missing_credential_without_adc(): + with pytest.raises( + ValueError, + match="service_account_credential is required", + ): + ServiceAccount( + use_default_credential=False, + scopes=_DEFAULT_SCOPES, + ) + + +def test_model_validator_allows_adc_without_explicit_credential(): + sa = ServiceAccount( + use_default_credential=True, + scopes=_DEFAULT_SCOPES, + ) + assert sa.service_account_credential is None + assert sa.use_default_credential is True From ee8d956413473d1bbbb025a470ad882c1487d8b8 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Feb 2026 11:15:35 -0800 Subject: [PATCH 3/8] fix: Update agent_engine_sandbox_code_executor in ADK 1. For prototyping and testing purposes, sandbox name can be provided, and it will be used for all requests across the lifecycle of an agent 2. If no sandbox name is provided, agent engine name will be provided, and we will automatically create one sandbox per session, and the sandbox has TTL set for a year. If the sandbox stored in the session hits the TTL, it will not be in "STATE_RUNNING" so a new sandbox will be created. PiperOrigin-RevId: 874705260 --- .../agent_engine_code_execution/README | 4 +- .../agent_engine_code_execution/agent.py | 7 +- .../agent_engine_sandbox_code_executor.py | 54 ++----- ...test_agent_engine_sandbox_code_executor.py | 133 ------------------ 4 files changed, 19 insertions(+), 179 deletions(-) diff --git a/contributing/samples/agent_engine_code_execution/README b/contributing/samples/agent_engine_code_execution/README index b0443ae228..8d5a444237 100644 --- a/contributing/samples/agent_engine_code_execution/README +++ b/contributing/samples/agent_engine_code_execution/README @@ -7,9 +7,9 @@ This sample data science agent uses Agent Engine Code Execution Sandbox to execu ## How to use -* 1. 
Follow https://docs.cloud.google.com/agent-builder/agent-engine/code-execution/quickstart#create-an-agent-engine-instance to create an agent engine instance. Replace the AGENT_ENGINE_RESOURCE_NAME with the one you just created. A new sandbox environment under this agent engine instance will be created for each session with TTL of 1 year. But sandbox can only main its state for up to 14 days. This is the recommended usage for production environments.
+* 1. Follow https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/code-execution/overview to create a code execution sandbox environment.
 
-* 2. For testing or protyping purposes, create a sandbox environment by following this guide: https://docs.cloud.google.com/agent-builder/agent-engine/code-execution/quickstart#create_a_sandbox. Replace the SANDBOX_RESOURCE_NAME with the one you just created. This will be used as the default sandbox environment for all the code executions throughout the lifetime of the agent. As the sandbox is re-used across sessions, all sessions will share the same Python environment and variable values."
+* 2. Replace the SANDBOX_RESOURCE_NAME with the one you just created. If you don't want to create a new sandbox environment directly, the Agent Engine Code Execution Sandbox will create one for you by default using the AGENT_ENGINE_RESOURCE_NAME you specified; however, please ensure that you clean up sandboxes after use; otherwise, they will consume quota.
 
 ## Sample prompt
 
diff --git a/contributing/samples/agent_engine_code_execution/agent.py b/contributing/samples/agent_engine_code_execution/agent.py
index a32e4ca4c6..d85989eb2d 100644
--- a/contributing/samples/agent_engine_code_execution/agent.py
+++ b/contributing/samples/agent_engine_code_execution/agent.py
@@ -85,10 +85,11 @@ def base_system_instruction():
     """,
     code_executor=AgentEngineSandboxCodeExecutor(
         # Replace with your sandbox resource name if you already have one. 
Only use it for testing or prototyping purposes, because this will use the same sandbox for all requests. + # Replace with your sandbox resource name if you already have one. + sandbox_resource_name="SANDBOX_RESOURCE_NAME", # "projects/vertex-agent-loadtest/locations/us-central1/reasoningEngines/6842889780301135872/sandboxEnvironments/6545148628569161728", - sandbox_resource_name=None, - # Replace with agent engine resource name used for creating sandbox environment. + # Replace with agent engine resource name used for creating sandbox if + # sandbox_resource_name is not set. agent_engine_resource_name="AGENT_ENGINE_RESOURCE_NAME", ), ) diff --git a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py index 9348dbc458..69d1778a5c 100644 --- a/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py +++ b/src/google/adk/code_executors/agent_engine_sandbox_code_executor.py @@ -21,7 +21,6 @@ from typing import Optional from typing_extensions import override -from vertexai import types from ..agents.invocation_context import InvocationContext from .base_code_executor import BaseCodeExecutor @@ -39,15 +38,10 @@ class AgentEngineSandboxCodeExecutor(BaseCodeExecutor): sandbox_resource_name: If set, load the existing resource name of the code interpreter extension instead of creating a new one. Format: projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789 - agent_engine_resource_name: The resource name of the agent engine to use - to create the code execution sandbox. 
Format: - projects/123/locations/us-central1/reasoningEngines/456 """ sandbox_resource_name: str = None - agent_engine_resource_name: str = None - def __init__( self, sandbox_resource_name: Optional[str] = None, @@ -73,19 +67,30 @@ def __init__( agent_engine_resource_name_pattern = r'^projects/([a-zA-Z0-9-_]+)/locations/([a-zA-Z0-9-_]+)/reasoningEngines/(\d+)$' if sandbox_resource_name is not None: + self.sandbox_resource_name = sandbox_resource_name self._project_id, self._location = ( self._get_project_id_and_location_from_resource_name( sandbox_resource_name, sandbox_resource_name_pattern ) ) - self.sandbox_resource_name = sandbox_resource_name elif agent_engine_resource_name is not None: + from vertexai import types + self._project_id, self._location = ( self._get_project_id_and_location_from_resource_name( agent_engine_resource_name, agent_engine_resource_name_pattern ) ) - self.agent_engine_resource_name = agent_engine_resource_name + # @TODO - Add TTL for sandbox creation after it is available + # in SDK. + operation = self._get_api_client().agent_engines.sandboxes.create( + spec={'code_execution_environment': {}}, + name=agent_engine_resource_name, + config=types.CreateAgentEngineSandboxConfig( + display_name='default_sandbox' + ), + ) + self.sandbox_resource_name = operation.response.name else: raise ValueError( 'Either sandbox_resource_name or agent_engine_resource_name must be' @@ -98,39 +103,6 @@ def execute_code( invocation_context: InvocationContext, code_execution_input: CodeExecutionInput, ) -> CodeExecutionResult: - if self.sandbox_resource_name is None: - sandbox_name = invocation_context.session.state.get('sandbox_name', None) - create_new_sandbox = False - if sandbox_name is None: - create_new_sandbox = True - else: - # Check if the sandbox is still running OR already expired due to ttl. 
- sandbox = self._get_api_client().agent_engines.sandboxes.get( - name=sandbox_name - ) - if not sandbox or sandbox.state != 'STATE_RUNNING': - create_new_sandbox = True - - if create_new_sandbox: - operation = self._get_api_client().agent_engines.sandboxes.create( - spec={'code_execution_environment': {}}, - name=self.agent_engine_resource_name, - config=types.CreateAgentEngineSandboxConfig( - # VertexAiSessionService has a default TTL of 1 year, so we set - # the sandbox TTL to 1 year as well. For the current code - # execution sandbox, if it hasn't been used for 14 days, the - # state will be lost. - display_name='default_sandbox', - ttl='31536000s', - ), - ) - self.sandbox_resource_name = operation.response.name - invocation_context.session.state['sandbox_name'] = ( - self.sandbox_resource_name - ) - else: - self.sandbox_resource_name = sandbox_name - # Execute the code. input_data = { 'code': code_execution_input.code, diff --git a/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py b/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py index 604685fe8f..6022527f9c 100644 --- a/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py +++ b/tests/unittests/code_executors/test_agent_engine_sandbox_code_executor.py @@ -19,7 +19,6 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.code_executors.agent_engine_sandbox_code_executor import AgentEngineSandboxCodeExecutor from google.adk.code_executors.code_execution_utils import CodeExecutionInput -from google.adk.sessions.session import Session import pytest @@ -28,10 +27,6 @@ def mock_invocation_context() -> InvocationContext: """Fixture for a mock InvocationContext.""" mock = MagicMock(spec=InvocationContext) mock.invocation_id = "test-invocation-123" - session = MagicMock(spec=Session) - mock.session = session - session.state = [] - return mock @@ -123,131 +118,3 @@ def test_execute_code_success( 
name="projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789", input_data={"code": 'print("hello world")'}, ) - - @patch("vertexai.Client") - def test_execute_code_recreates_sandbox_when_get_returns_none( - self, - mock_vertexai_client, - mock_invocation_context, - ): - # Setup Mocks - mock_api_client = MagicMock() - mock_vertexai_client.return_value = mock_api_client - - # Existing sandbox name stored in session, but get() will return None - existing_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/old" - mock_invocation_context.session.state = { - "sandbox_name": existing_sandbox_name - } - - # Mock get to return None (simulating missing/expired sandbox) - mock_api_client.agent_engines.sandboxes.get.return_value = None - - # Mock create operation to return a new sandbox resource name - operation_mock = MagicMock() - created_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789" - operation_mock.response.name = created_sandbox_name - mock_api_client.agent_engines.sandboxes.create.return_value = operation_mock - - # Mock execute_code response - mock_response = MagicMock() - mock_json_output = MagicMock() - mock_json_output.mime_type = "application/json" - mock_json_output.data = json.dumps( - {"stdout": "recreated sandbox run", "stderr": ""} - ).encode("utf-8") - mock_json_output.metadata = None - mock_response.outputs = [mock_json_output] - mock_api_client.agent_engines.sandboxes.execute_code.return_value = ( - mock_response - ) - - # Execute using agent_engine_resource_name so a sandbox can be created - executor = AgentEngineSandboxCodeExecutor( - agent_engine_resource_name=( - "projects/123/locations/us-central1/reasoningEngines/456" - ) - ) - code_input = CodeExecutionInput(code='print("hello world")') - result = executor.execute_code(mock_invocation_context, code_input) - - # Assert get was called for the existing sandbox - 
mock_api_client.agent_engines.sandboxes.get.assert_called_once_with( - name=existing_sandbox_name - ) - - # Assert create was called and session updated with new sandbox - mock_api_client.agent_engines.sandboxes.create.assert_called_once() - assert executor.sandbox_resource_name == created_sandbox_name - assert ( - mock_invocation_context.session.state["sandbox_name"] - == created_sandbox_name - ) - - # Assert execute_code used the created sandbox name - mock_api_client.agent_engines.sandboxes.execute_code.assert_called_once_with( - name=created_sandbox_name, - input_data={"code": 'print("hello world")'}, - ) - - @patch("vertexai.Client") - def test_execute_code_creates_sandbox_if_missing( - self, - mock_vertexai_client, - mock_invocation_context, - ): - # Setup Mocks - mock_api_client = MagicMock() - mock_vertexai_client.return_value = mock_api_client - - # Mock create operation to return a sandbox resource name - operation_mock = MagicMock() - created_sandbox_name = "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789" - operation_mock.response.name = created_sandbox_name - mock_api_client.agent_engines.sandboxes.create.return_value = operation_mock - - # Mock execute_code response - mock_response = MagicMock() - mock_json_output = MagicMock() - mock_json_output.mime_type = "application/json" - mock_json_output.data = json.dumps( - {"stdout": "created sandbox run", "stderr": ""} - ).encode("utf-8") - mock_json_output.metadata = None - mock_response.outputs = [mock_json_output] - mock_api_client.agent_engines.sandboxes.execute_code.return_value = ( - mock_response - ) - - # Ensure session.state behaves like a dict for storing sandbox_name - mock_invocation_context.session.state = {} - - # Execute using agent_engine_resource_name so a sandbox will be created - executor = AgentEngineSandboxCodeExecutor( - agent_engine_resource_name=( - "projects/123/locations/us-central1/reasoningEngines/456" - ), - sandbox_resource_name=None, - ) - 
code_input = CodeExecutionInput(code='print("hello world")') - result = executor.execute_code(mock_invocation_context, code_input) - - # Assert sandbox creation was called and session state updated - mock_api_client.agent_engines.sandboxes.create.assert_called_once() - create_call_kwargs = ( - mock_api_client.agent_engines.sandboxes.create.call_args.kwargs - ) - assert create_call_kwargs["name"] == ( - "projects/123/locations/us-central1/reasoningEngines/456" - ) - assert executor.sandbox_resource_name == created_sandbox_name - assert ( - mock_invocation_context.session.state["sandbox_name"] - == created_sandbox_name - ) - - # Assert execute_code used the created sandbox name - mock_api_client.agent_engines.sandboxes.execute_code.assert_called_once_with( - name=created_sandbox_name, - input_data={"code": 'print("hello world")'}, - ) From 48105b49c5ab8e4719a66e7219f731b2cd293b00 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Feb 2026 12:58:01 -0800 Subject: [PATCH 4/8] fix: Add support for injecting a custom google.genai.Client into Gemini models This change introduces a new `client` parameter to the `Gemini` model's constructor. When provided, this preconfigured `google.genai.Client` instance is used for all API calls, offering fine-grained control over authentication, project, and location settings Close #2560 PiperOrigin-RevId: 874752355 --- src/google/adk/models/google_llm.py | 56 -------- tests/unittests/models/test_google_llm.py | 150 ---------------------- 2 files changed, 206 deletions(-) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index b8c5117e19..23c9c27810 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -85,23 +85,6 @@ class Gemini(BaseLlm): Attributes: model: The name of the Gemini model. - client: An optional preconfigured ``google.genai.Client`` instance. 
- When provided, ADK uses this client for all API calls instead of - creating one internally from environment variables or ADC. This - allows fine-grained control over authentication, project, location, - and other client-level settings — and enables running agents that - target different Vertex AI regions within the same process. - - Example:: - - from google import genai - from google.adk.models import Gemini - - client = genai.Client( - vertexai=True, project="my-project", location="us-central1" - ) - model = Gemini(model="gemini-2.5-flash", client=client) - use_interactions_api: Whether to use the interactions API for model invocation. """ @@ -148,35 +131,6 @@ class Gemini(BaseLlm): ``` """ - def __init__(self, *, client: Optional[Client] = None, **kwargs: Any): - """Initialises a Gemini model wrapper. - - Args: - client: An optional preconfigured ``google.genai.Client``. When - provided, ADK uses this client for **all** Gemini API calls - (including the Live API) instead of creating one internally. - - .. note:: - When a custom client is supplied it is used as-is for Live API - connections. ADK will **not** override the client's - ``api_version``; you are responsible for setting the correct - version (``v1beta1`` for Vertex AI, ``v1alpha`` for the - Gemini developer API) on the client yourself. - - .. warning:: - ``google.genai.Client`` contains threading primitives that - cannot be pickled. If you are deploying to Agent Engine (or - any environment that serialises the model), do **not** pass a - custom client — let ADK create one from the environment - instead. - - **kwargs: Forwarded to the Pydantic ``BaseLlm`` constructor - (``model``, ``base_url``, ``retry_options``, etc.). - """ - super().__init__(**kwargs) - # Store after super().__init__ so Pydantic validation runs first. 
- object.__setattr__(self, '_client', client) - @classmethod @override def supported_models(cls) -> list[str]: @@ -345,16 +299,9 @@ async def _generate_content_via_interactions( def api_client(self) -> Client: """Provides the api client. - If a preconfigured ``client`` was passed to the constructor it is - returned directly; otherwise a new ``Client`` is created using the - default environment/ADC configuration. - Returns: The api client. """ - if self._client is not None: - return self._client - from google.genai import Client return Client( @@ -387,9 +334,6 @@ def _live_api_version(self) -> str: @cached_property def _live_api_client(self) -> Client: - if self._client is not None: - return self._client - from google.genai import Client return Client( diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 75d4c0fd48..70aa01b69d 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -2140,153 +2140,3 @@ async def __aexit__(self, *args): # Verify the final speech_config is still None assert config_arg.speech_config is None assert isinstance(connection, GeminiLlmConnection) - - -# --------------------------------------------------------------------------- -# Tests for custom client injection (Issue #2560) -# --------------------------------------------------------------------------- - - -def test_custom_client_is_used_for_api_client(): - """When a custom client is provided, api_client returns it directly.""" - from google.genai import Client - - custom_client = mock.MagicMock(spec=Client) - gemini = Gemini(model="gemini-1.5-flash", client=custom_client) - - assert gemini.api_client is custom_client - - -def test_custom_client_is_used_for_live_api_client(): - """When a custom client is provided, _live_api_client returns it directly.""" - from google.genai import Client - - custom_client = mock.MagicMock(spec=Client) - gemini = Gemini(model="gemini-1.5-flash", 
client=custom_client) - - assert gemini._live_api_client is custom_client - - -def test_default_api_client_when_no_custom_client(): - """Without a custom client, api_client creates a default Client.""" - gemini = Gemini(model="gemini-1.5-flash") - - # api_client should construct a real Client (not None) - client = gemini.api_client - assert client is not None - # Verify it is not a mock — it's a real google.genai.Client - from google.genai import Client - - assert isinstance(client, Client) - - -def test_default_live_api_client_when_no_custom_client(): - """Without a custom client, _live_api_client creates a default Client.""" - gemini = Gemini(model="gemini-1.5-flash") - - client = gemini._live_api_client - assert client is not None - from google.genai import Client - - assert isinstance(client, Client) - - -def test_custom_client_api_backend_vertexai(): - """_api_backend reflects the custom client's vertexai setting.""" - from google.genai import Client - - custom_client = mock.MagicMock(spec=Client) - custom_client.vertexai = True - gemini = Gemini(model="gemini-1.5-flash", client=custom_client) - - assert gemini._api_backend == GoogleLLMVariant.VERTEX_AI - - -def test_custom_client_api_backend_gemini_api(): - """_api_backend reflects non-vertexai custom client.""" - from google.genai import Client - - custom_client = mock.MagicMock(spec=Client) - custom_client.vertexai = False - gemini = Gemini(model="gemini-1.5-flash", client=custom_client) - - assert gemini._api_backend == GoogleLLMVariant.GEMINI_API - - -@pytest.mark.asyncio -async def test_custom_client_used_for_generate_content(): - """Custom client is used when generate_content_async is called.""" - from google.genai import Client - - custom_client = mock.MagicMock(spec=Client) - custom_client.vertexai = False - gemini = Gemini(model="gemini-1.5-flash", client=custom_client) - - generate_content_response = types.GenerateContentResponse( - candidates=[ - types.Candidate( - content=Content( - role="model", 
- parts=[Part.from_text(text="Hello from custom client")], - ), - finish_reason=types.FinishReason.STOP, - ) - ] - ) - - async def mock_coro(): - return generate_content_response - - custom_client.aio.models.generate_content.return_value = mock_coro() - - llm_request = LlmRequest( - model="gemini-1.5-flash", - contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], - config=types.GenerateContentConfig( - system_instruction="You are a helpful assistant", - ), - ) - - responses = [ - resp - async for resp in gemini.generate_content_async(llm_request, stream=False) - ] - - assert len(responses) == 1 - assert responses[0].content.parts[0].text == "Hello from custom client" - custom_client.aio.models.generate_content.assert_called_once() - - -@pytest.mark.asyncio -async def test_custom_client_used_for_live_connect(): - """Custom client is used for live API streaming connections.""" - from google.genai import Client - - custom_client = mock.MagicMock(spec=Client) - custom_client.vertexai = False - gemini = Gemini(model="gemini-1.5-flash", client=custom_client) - - mock_live_session = mock.AsyncMock() - - class MockLiveConnect: - - async def __aenter__(self): - return mock_live_session - - async def __aexit__(self, *args): - pass - - custom_client.aio.live.connect.return_value = MockLiveConnect() - - llm_request = LlmRequest( - model="gemini-1.5-flash", - contents=[Content(role="user", parts=[Part.from_text(text="Hello")])], - config=types.GenerateContentConfig( - system_instruction="You are a helpful assistant", - ), - ) - llm_request.live_connect_config = types.LiveConnectConfig() - - async with gemini.connect(llm_request) as connection: - custom_client.aio.live.connect.assert_called_once() - assert isinstance(connection, GeminiLlmConnection) From 121d27741684685c564e484704ae949c5f0807b1 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 24 Feb 2026 13:26:25 -0800 Subject: [PATCH 5/8] feat: Add /chat/completions streaming support to Apigee LLM 
PiperOrigin-RevId: 874764985 --- src/google/adk/models/apigee_llm.py | 481 ++++++++++++++---- .../models/test_completions_http_client.py | 341 ++++++++++++- 2 files changed, 716 insertions(+), 106 deletions(-) diff --git a/src/google/adk/models/apigee_llm.py b/src/google/adk/models/apigee_llm.py index 90a91f32d7..fc4928cb25 100644 --- a/src/google/adk/models/apigee_llm.py +++ b/src/google/adk/models/apigee_llm.py @@ -25,6 +25,7 @@ import os from typing import Any from typing import AsyncGenerator +from typing import Generator from typing import Optional from typing import TYPE_CHECKING @@ -51,6 +52,14 @@ _PROJECT_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_PROJECT' _LOCATION_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_LOCATION' +_CUSTOM_METADATA_FIELDS = ( + 'id', + 'created', + 'model', + 'service_tier', + 'object', +) + class ApigeeLlm(Gemini): """A BaseLlm implementation for calling Apigee proxy. @@ -290,6 +299,45 @@ def _get_model_id(model: str) -> str: return components[-1] +def _parse_logprobs( + logprobs_data: dict[str, Any] | None, +) -> types.LogprobsResult | None: + """Parses OpenAI logprobs data into LogprobsResult.""" + if not logprobs_data or 'content' not in logprobs_data: + return None + + chosen_candidates = [] + top_candidates = [] + + for item in logprobs_data['content']: + chosen_candidates.append( + types.LogprobsResultCandidate( + token=item.get('token'), + log_probability=item.get('logprob'), + # OpenAI text format usually doesn't expose ID easily here + token_id=None, + ) + ) + + if 'top_logprobs' in item: + current_top_candidates = [] + for top_item in item['top_logprobs']: + current_top_candidates.append( + types.LogprobsResultCandidate( + token=top_item.get('token'), + log_probability=top_item.get('logprob'), + token_id=None, + ) + ) + top_candidates.append( + types.LogprobsResultTopCandidates(candidates=current_top_candidates) + ) + + return types.LogprobsResult( + chosen_candidates=chosen_candidates, top_candidates=top_candidates + ) + + def 
_validate_model_string(model: str) -> bool: """Validates the model string for Apigee LLM. @@ -383,7 +431,7 @@ def _cleanup_client(client: httpx.AsyncClient) -> None: loop.create_task(client.aclose()) except RuntimeError: try: - # This fails if aynscio.run is already called in main and is being closed. + # This fails if asyncio.run is already called in main and is closing. asyncio.run(client.aclose()) except RuntimeError: pass @@ -470,7 +518,8 @@ async def generate_content_async( url = f"{url.rstrip('/')}/chat/completions" if stream: - raise NotImplementedError('Streaming is not supported yet.') + async for stream_res in self._handle_streaming(url, payload, headers): + yield stream_res else: response = await self._httpx_post_with_retry(url, payload, headers) data = response.json() @@ -487,11 +536,33 @@ async def _httpx_post_with_retry( response.raise_for_status() return response - async def _handle_streaming_response( - self, response: httpx.Response + async def _handle_streaming( + self, url: str, payload: dict[str, Any], headers: dict[str, str] ) -> AsyncGenerator[LlmResponse, None]: """Handles streaming response from OpenAI-compatible API.""" - raise NotImplementedError('Streaming is not supported yet.') + accumulator = ChatCompletionsResponseHandler() + async with self._client.stream( + 'POST', + url, + json=payload, + headers=headers, + ) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + if not line: + continue + line = line.strip() + if line.startswith('data:'): + line = line.removeprefix('data:') + line = line.lstrip() + if line == '[DONE]': + break + try: + for res in self._parse_streaming_line(line, accumulator): + yield res + except json.JSONDecodeError: + logger.warning('Failed to parse JSON chunk: %s', line) + continue def _construct_payload( self, llm_request: LlmRequest, stream: bool @@ -731,78 +802,62 @@ def _serialize_system_instruction( return ''.join(part.text for part in parts if part.text) return None - def 
_parse_logprobs( - self, logprobs_data: dict[str, Any] | None - ) -> types.LogprobsResult | None: - """Parses OpenAI logprobs data into LogprobsResult.""" - if not logprobs_data or 'content' not in logprobs_data: - return None + def _parse_response(self, response: dict[str, Any]) -> LlmResponse: + """Parses an OpenAI response dictionary into an LlmResponse.""" + handler = ChatCompletionsResponseHandler() + return handler.process_response(response) - chosen_candidates = [] - top_candidates = [] - - for item in logprobs_data['content']: - chosen_candidates.append( - types.LogprobsResultCandidate( - token=item.get('token'), - log_probability=item.get('logprob'), - # OpenAI text format usually doesn't expose ID easily here - token_id=None, - ) - ) + def _parse_streaming_line( + self, + line: str, + accumulator: ChatCompletionsResponseHandler, + ) -> Generator[LlmResponse]: + """Parses a single line from the streaming response. - if 'top_logprobs' in item: - current_top_candidates = [] - for top_item in item['top_logprobs']: - current_top_candidates.append( - types.LogprobsResultCandidate( - token=top_item.get('token'), - log_probability=top_item.get('logprob'), - token_id=None, - ) - ) - top_candidates.append( - types.LogprobsResultTopCandidates(candidates=current_top_candidates) - ) + Args: + line: A single line from the streaming response, expected to be a JSON + string. + accumulator: An accumulator to manage partial chat completion choices + across multiple chunks. - return types.LogprobsResult( - chosen_candidates=chosen_candidates, top_candidates=top_candidates - ) + Yields: + An LlmResponse object parsed from the streaming line. 
+ """ + chunk = json.loads(line) + for response in accumulator.process_chunk(chunk): + yield response - def _parse_response(self, response: dict[str, Any]) -> LlmResponse: - """Parses an OpenAI response dictionary into an LlmResponse.""" + +class ChatCompletionsResponseHandler: + """Accumulates responses from the /chat/completions endpoint. + + Useful for both streaming and non-streaming responses. + """ + + def __init__(self): + self.content_parts = '' + self.tool_call_parts = {} + self.role = '' + self.streaming_complete = False + self.model = '' + self.usage = {} + self.logprobs = {} + self.custom_metadata = {} + + def process_response(self, response: dict[str, Any]) -> LlmResponse: + """Processes a complete non-streaming response.""" choices = response.get('choices', []) if not choices: - return LlmResponse() - + raise ValueError('No choices found in response.') + if len(choices) > 1: + logging.error( + 'Multiple choices found in response but only the first one will be' + ' used.' + ) choice = choices[0] message = choice.get('message', {}) - role = message.get('role', 'model') - if role == 'assistant': - role = 'model' - - parts = [] - content_str = message.get('content') - if content_str: - parts.append(types.Part.from_text(text=content_str)) - - tool_calls = message.get('tool_calls') - if tool_calls: - for tool_call in tool_calls: - call_type = tool_call.get('type', 'unknown') - # TODO: Add support for 'custom' type. 
- if call_type != 'function': - raise ValueError( - f'Unsupported tool_call type: {call_type} in call {tool_call}' - ) - func = tool_call.get('function', {}) - part = self._parse_function_call(func) - parts.append(part) - - function_call = message.get('function_call') - if function_call: - part = self._parse_function_call(function_call) - parts.append(part) + _, role = self._add_chat_completion_message(message) + parts = self._get_content_parts() usage = response.get('usage', {}) usage_metadata = types.GenerateContentResponseUsageMetadata( @@ -810,19 +865,13 @@ def _parse_response(self, response: dict[str, Any]) -> LlmResponse: candidates_token_count=usage.get('completion_tokens', 0), total_token_count=usage.get('total_tokens', 0), ) + logprobs_result = _parse_logprobs(choice.get('logprobs')) - logprobs_result = self._parse_logprobs(choice.get('logprobs')) - - custom_metadata = { - 'id': response.get('id'), - 'created': response.get('created'), - 'model': response.get('model'), - 'system_fingerprint': response.get('system_fingerprint'), - 'service_tier': response.get('service_tier'), - } - custom_metadata = { - k: v for k, v in custom_metadata.items() if v is not None - } + custom_metadata = {} + for k in _CUSTOM_METADATA_FIELDS: + v = response.get(k) + if v is not None: + custom_metadata[k] = v return LlmResponse( content=types.Content(role=role, parts=parts), @@ -833,6 +882,83 @@ def _parse_response(self, response: dict[str, Any]) -> LlmResponse: custom_metadata=custom_metadata, ) + def process_chunk( + self, chunk: dict[str, Any] + ) -> Generator[LlmResponse, None, None]: + """Processes a chunk and yields responses.""" + if 'model' in chunk: + self.model = chunk['model'] + if 'usage' in chunk and chunk['usage']: + self.usage.update(chunk['usage']) + + for k in _CUSTOM_METADATA_FIELDS: + v = chunk.get(k) + if v is not None: + self.custom_metadata[k] = v + + usage_metadata = None + if self.usage: + usage_metadata = types.GenerateContentResponseUsageMetadata( + 
prompt_token_count=self.usage.get('prompt_tokens', 0), + candidates_token_count=self.usage.get('completion_tokens', 0), + total_token_count=self.usage.get('total_tokens', 0), + ) + + choices = chunk.get('choices') + if not choices: + # If no choices, but we have usage or other metadata updates, yield them. + if usage_metadata or self.custom_metadata: + yield LlmResponse( + partial=True, + model_version=self.model, + usage_metadata=usage_metadata, + custom_metadata=self.custom_metadata, + ) + return + + if len(choices) > 1: + logging.error( + 'Multiple choices found in streaming response but only the first one' + ' will be used.' + ) + choice = choices[0] + + # Accumulate logprobs if present + if 'logprobs' in choice and choice['logprobs']: + self._accumulate_logprobs(choice['logprobs']) + + logprobs_result = None + if self.logprobs: + logprobs_result = _parse_logprobs(self.logprobs) + + delta = choice.get('delta', {}) + partial_parts, role = self._add_chat_completion_chunk_delta(delta) + + yield LlmResponse( + partial=True, + content=types.Content(role=role, parts=partial_parts), + model_version=self.model, + usage_metadata=usage_metadata, + custom_metadata=self.custom_metadata, + logprobs_result=logprobs_result, + ) + + finish_reason = choice.get('finish_reason') + if finish_reason: + yield LlmResponse( + content=types.Content( + role=role, + parts=self._get_content_parts(), + ), + finish_reason=self._map_finish_reason(finish_reason), + custom_metadata=self.custom_metadata, + model_version=self.model, + usage_metadata=usage_metadata, + logprobs_result=logprobs_result, + ) + # Exit because the 'finish_reason' chunk is the final chunk. 
+ return + def _map_finish_reason(self, reason: str | None) -> types.FinishReason: if reason == 'stop': return types.FinishReason.STOP @@ -844,25 +970,176 @@ def _map_finish_reason(self, reason: str | None) -> types.FinishReason: return types.FinishReason.SAFETY return types.FinishReason.FINISH_REASON_UNSPECIFIED - def _parse_function_call(self, func: dict[str, Any]) -> types.Part: - """Parses a function call dictionary into a Part.""" - name = func.get('name') - args_str = func.get('arguments', '{}') - try: - args = json.loads(args_str) - except json.JSONDecodeError: - args = {} - tool_part = types.Part.from_function_call(name=name, args=args) - if tool_part.function_call: - tool_part.function_call.id = func.get('id', None) - # Add support for gemini's thought_signature. - thought_signature = ( - func.get('extra_content', {}) - .get('google', {}) - .get('thought_signature', '') + def _accumulate_logprobs(self, logprobs_chunk: dict[str, Any]) -> None: + """Accumulates logprobs from a chunk.""" + if not self.logprobs: + self.logprobs = {'content': [], 'refusal': []} + + if 'content' in logprobs_chunk and logprobs_chunk['content']: + if 'content' not in self.logprobs: + self.logprobs['content'] = [] + self.logprobs['content'].extend(logprobs_chunk['content']) + + if 'refusal' in logprobs_chunk and logprobs_chunk['refusal']: + if 'refusal' not in self.logprobs: + self.logprobs['refusal'] = [] + self.logprobs['refusal'].extend(logprobs_chunk['refusal']) + + def _append_content(self, content: str, refusal: str) -> str: + if content and refusal: + content += '\n' + content += refusal + elif refusal: + content = refusal + if content: + self.content_parts += content + return content + + def _add_chat_completion_chunk_delta( + self, delta: dict[str, Any] + ) -> (list[types.Part], str): + """Adds a chunk delta from a streaming chat completions response. 
+ + This method processes a single delta chunk from a streaming chat completions + response, accumulating partial content and tool calls. + + Args: + delta: A dictionary representing a single delta from the streaming chat + completions API. + + Returns: + A tuple containing: + - A list of `types.Part` objects representing the content and tool calls + in this chunk. + - The role associated with the message. + """ + parts = [] + for tool_call in delta.get('tool_calls', []): + chunk_part = self._upsert_tool_call(tool_call) + parts.append(chunk_part) + content = delta.get('content') + refusal = delta.get('refusal') + merged_content = self._append_content(content, refusal) + if merged_content: + parts.append(types.Part.from_text(text=merged_content)) + + self._get_or_create_role(delta.get('role', 'model')) + return parts, self.role + + def _add_chat_completion_message( + self, message: dict[str, Any] + ) -> (list[types.Part], str): + """Adds a complete chat completion message to the accumulator. + + This method processes a single message from a non-streaming chat completions + response, extracting and accumulating content and tool calls. + + Args: + message: A dictionary representing a single message from the chat + completions API. + + Returns: + A tuple containing: + - A list of `types.Part` objects representing the content and tool calls + in this message. + - The role associated with the message. + """ + for tool_call in message.get('tool_calls', []): + self._upsert_tool_call(tool_call) + function_call = message.get('function_call') + if function_call: + # function_call is a single tool call and does not have an id. 
+ self._upsert_tool_call({ + 'type': 'function', + 'function': function_call, + }) + content = message.get('content') + refusal = message.get('refusal') + self._append_content(content, refusal) + + self._get_or_create_role(message.get('role', 'model')) + return self._get_content_parts(), self.role + + def _get_content_parts(self) -> list[types.Part]: + """Returns the content parts from the accumulated response.""" + parts = [] + if self.content_parts: + parts.append(types.Part.from_text(text=self.content_parts)) + sorted_indices = sorted(self.tool_call_parts.keys()) + for index in sorted_indices: + parts.append(self.tool_call_parts[index]) + return parts + + def _upsert_tool_call(self, tool_call: dict[str, Any]) -> types.Part: + """Upserts a tool call into the accumulated tool call parts. + + This method handles partial tool call chunks in streaming responses by + updating existing tool call parts or creating new ones. + + Args: + tool_call: A dictionary representing a tool call or a delta of a tool call + from the chat completions API. + + Returns: + A `types.Part` object representing the updated or newly created tool call. + """ + index = tool_call.get('index') + if index is None: + # If index is not provided, we might be in a non-streaming response. + # We just append it as a new tool call. + index = len(self.tool_call_parts) + + if index not in self.tool_call_parts: + self.tool_call_parts[index] = types.Part( + function_call=types.FunctionCall() ) - if thought_signature: - if isinstance(thought_signature, str): - thought_signature = base64.b64decode(thought_signature) - tool_part.thought_signature = thought_signature - return tool_part + part = self.tool_call_parts[index] + chunk_part = types.Part(function_call=types.FunctionCall()) + call_type = tool_call.get('type') + # TODO: Add support for 'custom' type. 
+ if call_type is not None and call_type != 'function': + raise ValueError( + f'Unsupported tool_call type: {call_type} in call {tool_call}' + ) + func = tool_call.get('function', {}) + args_delta = func.get('arguments', '') + if args_delta: + try: + args = json.loads(args_delta) + chunk_part.function_call.args = args + if not part.function_call.args: + part.function_call.args = dict(args) + else: + part.function_call.args.update(args) + except json.JSONDecodeError as e: + raise ValueError(f'Failed to parse arguments: {args_delta}') from e + + func_name = func.get('name') + if func_name: + part.function_call.name = func_name + chunk_part.function_call.name = func_name + tool_call_id = tool_call.get('id') + if tool_call_id: + part.function_call.id = tool_call_id + chunk_part.function_call.id = tool_call_id + + # Add support for gemini's thought_signature. + thought_signature = ( + tool_call.get('extra_content', {}) + .get('google', {}) + .get('thought_signature', '') + ) + if thought_signature: + if isinstance(thought_signature, str): + thought_signature = base64.b64decode(thought_signature) + part.thought_signature = thought_signature + chunk_part.thought_signature = thought_signature + return chunk_part + + def _get_or_create_role(self, role: str = '') -> str: + if self.role: + return self.role + if role == 'assistant': + role = 'model' + self.role = role + return self.role diff --git a/tests/unittests/models/test_completions_http_client.py b/tests/unittests/models/test_completions_http_client.py index f16376d7fd..615871eb32 100644 --- a/tests/unittests/models/test_completions_http_client.py +++ b/tests/unittests/models/test_completions_http_client.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json from unittest import mock from unittest.mock import AsyncMock @@ -24,7 +25,7 @@ @pytest.fixture def client(): - return CompletionsHTTPClient(base_url='https://example.com') + return CompletionsHTTPClient(base_url='https://localhost') @pytest.fixture(name='llm_request') @@ -58,7 +59,7 @@ async def test_construct_payload_basic_payload(client, llm_request): url = call_args[0][0] kwargs = call_args[1] - assert url == 'https://example.com/chat/completions' + assert url == 'https://localhost/chat/completions' payload = kwargs['json'] assert payload['model'] == 'open_llama' assert payload['stream'] is False @@ -231,7 +232,7 @@ async def test_construct_payload_image_file_uri(client): role='user', parts=[ types.Part.from_uri( - file_uri='https://example.com/image.jpg', + file_uri='https://localhost/image.jpg', mime_type='image/jpeg', ) ], @@ -263,7 +264,7 @@ async def test_construct_payload_image_file_uri(client): assert isinstance(message['content'], list) assert message['content'][0] == { 'type': 'image_url', - 'image_url': {'url': 'https://example.com/image.jpg'}, + 'image_url': {'url': 'https://localhost/image.jpg'}, } @@ -368,6 +369,7 @@ async def test_construct_payload_response_format( mock_post.assert_called_once() payload = mock_post.call_args[1]['json'] + assert payload['response_format'] == expected_response_format @@ -438,3 +440,334 @@ async def test_generate_content_async_function_call_response( assert part.function_call.name == 'get_weather' assert part.function_call.args == {'location': 'London'} assert part.function_call.id is None + + +@pytest.mark.asyncio +async def test_generate_content_async_streaming_function_call(): + local_client = CompletionsHTTPClient(base_url='https://localhost') + llm_request = LlmRequest( + model='apigee/test', + contents=[ + types.Content(role='user', parts=[types.Part.from_text(text='hi')]) + ], + ) + + # Mock chunks simulating split arguments + chunk_data_0 = { + 'id': 'chatcmpl-123', + 'object': 
'chat.completion.chunk', + 'created': 1234567890, + 'model': 'gpt-3.5-turbo', + 'service_tier': 'default', + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [{ + 'index': 0, + 'id': 'call_123', + 'type': 'function', + 'function': {'name': 'get_weather', 'arguments': ''}, + }] + }, + 'finish_reason': None, + }], + } + chunk_data_1 = { + 'id': 'chatcmpl-123', + 'object': 'chat.completion.chunk', + 'created': 1234567890, + 'model': 'gpt-3.5-turbo', + 'service_tier': 'default', + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [{ + 'index': 0, + 'function': {'arguments': '{"location": "London"}'}, + }] + }, + 'finish_reason': None, + }], + } + chunk_data_2 = { + 'id': 'chatcmpl-123', + 'object': 'chat.completion.chunk', + 'created': 1234567890, + 'model': 'gpt-3.5-turbo', + 'service_tier': 'default', + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [{ + 'index': 0, + 'function': {'arguments': '{"country": "UK"}'}, + }] + }, + 'finish_reason': None, + }], + } + chunk_data_3 = { + 'id': 'chatcmpl-123', + 'object': 'chat.completion.chunk', + 'created': 1234567890, + 'model': 'gpt-3.5-turbo', + 'service_tier': 'default', + 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'tool_calls'}], + 'usage': { + 'prompt_tokens': 10, + 'completion_tokens': 20, + 'total_tokens': 30, + }, + } + + chunks = [ + f'{json.dumps(chunk_data_0)}\n', + f'{json.dumps(chunk_data_1)}\n', + f'{json.dumps(chunk_data_2)}\n', + f'{json.dumps(chunk_data_3)}\n', + ] + + async def mock_aiter_lines(): + for chunk in chunks: + yield chunk + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.aiter_lines.return_value = mock_aiter_lines() + mock_response.status_code = 200 + + mock_stream_ctx = mock.AsyncMock() + mock_stream_ctx.__aenter__.return_value = mock_response + + with mock.patch.object( + httpx.AsyncClient, 'stream', return_value=mock_stream_ctx + ): + responses = [ + r + async for r in local_client.generate_content_async( + llm_request, stream=True + ) + 
]
+  # Check that we get 5 responses (one per chunk + extra final accumulated)
+  assert len(responses) == 5
+
+  # Check 1st response: partial tool call, empty args
+  assert responses[0].partial is True
+  assert responses[0].content.parts[0].function_call.name == 'get_weather'
+  assert responses[0].content.parts[0].function_call.id == 'call_123'
+
+  # Check 2nd response: full args for first update
+  assert responses[1].partial is True
+  assert responses[1].content.parts[0].function_call.args == {
+      'location': 'London'
+  }
+
+  # Check 3rd response: full args for second update (merged)
+  assert responses[2].partial is True
+  assert responses[2].content.parts[0].function_call.args == {'country': 'UK'}
+
+  # Check 4th response: last delta (empty)
+  assert responses[3].partial is True
+  assert responses[3].content.parts == []
+
+  # Check 5th response: final accumulated
+  assert responses[4].finish_reason == types.FinishReason.STOP
+  # Full accumulated args
+  assert responses[4].content.parts[0].function_call.args == {
+      'location': 'London',
+      'country': 'UK',
+  }
+
+  # Check metadata and usage
+  assert responses[4].model_version == 'gpt-3.5-turbo'
+  assert responses[4].custom_metadata['id'] == 'chatcmpl-123'
+  assert responses[4].custom_metadata['created'] == 1234567890
+  assert responses[4].custom_metadata['object'] == 'chat.completion.chunk'
+  assert responses[4].custom_metadata['service_tier'] == 'default'
+  assert responses[4].usage_metadata is not None
+  assert responses[4].usage_metadata.prompt_token_count == 10
+  assert responses[4].usage_metadata.candidates_token_count == 20
+  assert responses[4].usage_metadata.total_token_count == 30
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_streaming_multiple_function_calls():
+  # Mock streaming response with multiple tool calls
+  local_client = CompletionsHTTPClient(base_url='https://localhost')
+  llm_request = LlmRequest(
+      model='apigee/test',
+      contents=[
+          types.Content(role='user', 
parts=[types.Part.from_text(text='hi')]) + ], + ) + chunk_data_1 = { + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [ + { + 'index': 0, + 'id': 'call_1', + 'type': 'function', + 'function': {'name': 'func_1', 'arguments': ''}, + }, + { + 'index': 1, + 'id': 'call_2', + 'type': 'function', + 'function': {'name': 'func_2', 'arguments': ''}, + }, + ] + }, + 'finish_reason': None, + }] + } + # the tool_call type is optional in chunk responses. + chunk_data_2 = { + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [ + {'index': 0, 'function': {'arguments': '{"arg": 1}'}}, + {'index': 1, 'function': {'arguments': '{"arg": 2}'}}, + ] + }, + 'finish_reason': None, + }] + } + chunk_data_3 = { + 'choices': [{'index': 0, 'delta': {}, 'finish_reason': 'tool_calls'}] + } + + chunks = [ + f'{json.dumps(chunk_data_1)}\n', + f'{json.dumps(chunk_data_2)}\n', + f'{json.dumps(chunk_data_3)}\n', + ] + + async def mock_aiter_lines(): + for chunk in chunks: + yield chunk + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.aiter_lines.return_value = mock_aiter_lines() + mock_response.status_code = 200 + + mock_stream_ctx = mock.AsyncMock() + mock_stream_ctx.__aenter__.return_value = mock_response + + with mock.patch.object( + httpx.AsyncClient, 'stream', return_value=mock_stream_ctx + ): + responses = [ + r + async for r in local_client.generate_content_async( + llm_request, stream=True + ) + ] + + assert len(responses) == 4 + parts = responses[-1].content.parts + assert len(parts) == 2 + + assert parts[0].function_call.name == 'func_1' + assert parts[0].function_call.args == {'arg': 1} + assert parts[0].function_call.id == 'call_1' + + assert parts[1].function_call.name == 'func_2' + assert parts[1].function_call.args == {'arg': 2} + + assert parts[1].function_call.id == 'call_2' + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ('chunks', 'expected_response_count'), + [ + ( + [ + '\n', + ' \n', + ( + 'data: {"choices": [{"index": 0, "delta": 
{"content":' + ' "Hello"}, "finish_reason": null}]}\n' + ), + ], + 1, + ), + ( + [ + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "Hello"}, "finish_reason": null}]}\n' + ), + '[DONE]\n', + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "World"}, "finish_reason": "stop"}]}\n' + ), + ], + 1, # Should stop after [DONE] + ), + ( + [ + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "Hello"}, "finish_reason": null}]}\n' + ), + ' [DONE] \n', + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "World"}, "finish_reason": "stop"}]}\n' + ), + ], + 1, # Should stop after [DONE] + ), + ( + [ + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "Hello"}, "finish_reason": null}]}\n' + ), + 'data: [DONE]\n', + ( + 'data: {"choices": [{"index": 0, "delta": {"content":' + ' "World"}, "finish_reason": "stop"}]}\n' + ), + ], + 1, # Should stop after [DONE] + ), + ], +) +async def test_generate_content_async_streaming_parse_lines( + chunks, expected_response_count +): + local_client = CompletionsHTTPClient(base_url='https://localhost') + llm_request = LlmRequest( + model='apigee/test', + contents=[ + types.Content(role='user', parts=[types.Part.from_text(text='hi')]) + ], + ) + + async def mock_aiter_lines(): + for chunk in chunks: + yield chunk + + mock_response = AsyncMock(spec=httpx.Response) + mock_response.aiter_lines.return_value = mock_aiter_lines() + mock_response.status_code = 200 + + mock_stream_ctx = mock.AsyncMock() + mock_stream_ctx.__aenter__.return_value = mock_response + + with mock.patch.object( + httpx.AsyncClient, 'stream', return_value=mock_stream_ctx + ): + responses = [ + r + async for r in local_client.generate_content_async( + llm_request, stream=True + ) + ] + assert len(responses) == expected_response_count + assert responses[0].content.parts[0].text == 'Hello' From 8f5428150d18ed732b66379c0acb806a9121c3cb Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Tue, 24 Feb 2026 14:34:52 -0800 
Subject: [PATCH 6/8] fix: Update sample skills agent to use weather-skill instead of weather_skill Co-authored-by: Kathy Wu PiperOrigin-RevId: 874796345 --- contributing/samples/skills_agent/agent.py | 2 +- .../skills/{weather_skill => weather-skill}/SKILL.md | 0 .../{weather_skill => weather-skill}/references/weather_info.md | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename contributing/samples/skills_agent/skills/{weather_skill => weather-skill}/SKILL.md (100%) rename contributing/samples/skills_agent/skills/{weather_skill => weather-skill}/references/weather_info.md (100%) diff --git a/contributing/samples/skills_agent/agent.py b/contributing/samples/skills_agent/agent.py index 6cd69ffb63..9caf0ad752 100644 --- a/contributing/samples/skills_agent/agent.py +++ b/contributing/samples/skills_agent/agent.py @@ -41,7 +41,7 @@ ) weather_skill = load_skill_from_dir( - pathlib.Path(__file__).parent / "skills" / "weather_skill" + pathlib.Path(__file__).parent / "skills" / "weather-skill" ) my_skill_toolset = SkillToolset(skills=[greeting_skill, weather_skill]) diff --git a/contributing/samples/skills_agent/skills/weather_skill/SKILL.md b/contributing/samples/skills_agent/skills/weather-skill/SKILL.md similarity index 100% rename from contributing/samples/skills_agent/skills/weather_skill/SKILL.md rename to contributing/samples/skills_agent/skills/weather-skill/SKILL.md diff --git a/contributing/samples/skills_agent/skills/weather_skill/references/weather_info.md b/contributing/samples/skills_agent/skills/weather-skill/references/weather_info.md similarity index 100% rename from contributing/samples/skills_agent/skills/weather_skill/references/weather_info.md rename to contributing/samples/skills_agent/skills/weather-skill/references/weather_info.md From e4d9540ce3552ffd3335e1776293eafee4ea28cd Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Tue, 24 Feb 2026 23:29:06 -0800 Subject: [PATCH 7/8] chore: Make `Release: Please` workflow only run via workflow_dispatch 
Co-authored-by: Xuan Yang PiperOrigin-RevId: 874980878 --- .github/workflows/release-please.yml | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/.github/workflows/release-please.yml b/.github/workflows/release-please.yml index 791d84a5b6..41d8d864c2 100644 --- a/.github/workflows/release-please.yml +++ b/.github/workflows/release-please.yml @@ -1,11 +1,10 @@ # Runs release-please to create/update a PR with version bump and changelog. -# Triggered automatically by step 1 (cut) or step 3 (cherry-pick). +# Triggered only by workflow_dispatch (from release-cut.yml). +# Does NOT auto-run on push to preserve manual changelog edits after cherry-picks. name: "Release: Please" on: - push: - branches: - - release/candidate + # Only run via workflow_dispatch (triggered by release-cut.yml) workflow_dispatch: permissions: @@ -14,8 +13,6 @@ permissions: jobs: release-please: - # Skip if this is a release-please PR merge (handled by Release: Finalize) - if: "!startsWith(github.event.head_commit.message, 'chore(release')" runs-on: ubuntu-latest steps: - name: Check if release/candidate still exists From 636f68fbee700aa47f01e2cfd746859353b3333d Mon Sep 17 00:00:00 2001 From: Haiyuan Cao Date: Wed, 25 Feb 2026 00:58:38 -0800 Subject: [PATCH 8/8] feat: Add RunSkillScriptTool to SkillToolset Introduces RunSkillScriptTool to execute scripts located in a skill's scripts/ directory. The execution logic is isolated within a dedicated SkillScriptCodeExecutor wrapper instantiated by RunSkillScriptTool. This wrapper manages script materialization in a temporary directory and executes Python (via runpy) or Shell scripts (returning standard output or JSON-encoded envelopes). This isolation eliminates the need to modify the underlying `BaseCodeExecutor` interface or implementations (`unsafe_local_code_executor`, etc.) to support working directories or file paths. 
Co-authored-by: Haiyuan Cao PiperOrigin-RevId: 875012237 --- src/google/adk/tools/skill_toolset.py | 366 ++++++++ tests/unittests/tools/test_skill_toolset.py | 894 +++++++++++++++++++- 2 files changed, 1256 insertions(+), 4 deletions(-) diff --git a/src/google/adk/tools/skill_toolset.py b/src/google/adk/tools/skill_toolset.py index f90dfdb2b1..d13481eba3 100644 --- a/src/google/adk/tools/skill_toolset.py +++ b/src/google/adk/tools/skill_toolset.py @@ -12,16 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +# pylint: disable=g-import-not-at-top,protected-access + """Toolset for discovering, viewing, and executing agent skills.""" from __future__ import annotations +import asyncio +import json +import logging from typing import Any +from typing import Optional from typing import TYPE_CHECKING from google.genai import types from ..agents.readonly_context import ReadonlyContext +from ..code_executors.base_code_executor import BaseCodeExecutor +from ..code_executors.code_execution_utils import CodeExecutionInput from ..features import experimental from ..features import FeatureName from ..skills import models @@ -33,6 +41,11 @@ if TYPE_CHECKING: from ..models.llm_request import LlmRequest +logger = logging.getLogger("google_adk." + __name__) + +_DEFAULT_SCRIPT_TIMEOUT = 300 +_MAX_SKILL_PAYLOAD_BYTES = 16 * 1024 * 1024 # 16 MB + DEFAULT_SKILL_SYSTEM_INSTRUCTION = """You can use specialized 'skills' to help you with complex tasks. You MUST use the skill tools to interact with these skills. Skills are folders of instructions and resources that extend your capabilities for specialized tasks. Each skill folder contains: @@ -46,6 +59,7 @@ 1. If a skill seems relevant to the current user query, you MUST use the `load_skill` tool with `name=""` to read its full instructions before proceeding. 2. Once you have read the instructions, follow them exactly as documented before replying to the user. 
For example, If the instruction lists multiple steps, please make sure you complete all of them in order. 3. The `load_skill_resource` tool is for viewing files within a skill's directory (e.g., `references/*`, `assets/*`, `scripts/*`). Do NOT use other tools to access these files. +4. Use `run_skill_script` to run scripts from a skill's `scripts/` directory. Use `load_skill_resource` to view script content first if needed. """ @@ -227,6 +241,340 @@ async def run_async( } +class _SkillScriptCodeExecutor: + """A helper that materializes skill files and executes scripts.""" + + _base_executor: BaseCodeExecutor + _script_timeout: int + + def __init__(self, base_executor: BaseCodeExecutor, script_timeout: int): + self._base_executor = base_executor + self._script_timeout = script_timeout + + async def execute_script_async( + self, + invocation_context: Any, + skill: models.Skill, + script_path: str, + script_args: dict[str, Any], + ) -> dict[str, Any]: + """Prepares and executes the script using the base executor.""" + code = self._build_wrapper_code(skill, script_path, script_args) + if code is None: + if "." in script_path: + ext_msg = f"'.{script_path.rsplit('.', 1)[-1]}'" + else: + ext_msg = "(no extension)" + return { + "error": ( + f"Unsupported script type {ext_msg}." + " Supported types: .py, .sh, .bash" + ), + "error_code": "UNSUPPORTED_SCRIPT_TYPE", + } + + try: + # Execute the self-contained script using the underlying executor + result = await asyncio.to_thread( + self._base_executor.execute_code, + invocation_context, + CodeExecutionInput(code=code), + ) + + stdout = result.stdout + stderr = result.stderr + + # Shell scripts serialize both streams as JSON + # through stdout; parse the envelope if present. + rc = 0 + is_shell = "." 
in script_path and script_path.rsplit(".", 1)[ + -1 + ].lower() in ("sh", "bash") + if is_shell and stdout: + try: + parsed = json.loads(stdout) + if isinstance(parsed, dict) and parsed.get("__shell_result__"): + stdout = parsed.get("stdout", "") + stderr = parsed.get("stderr", "") + rc = parsed.get("returncode", 0) + if rc != 0 and not stderr: + stderr = f"Exit code {rc}" + except (json.JSONDecodeError, ValueError): + pass + + status = "success" + if rc != 0: + status = "error" + elif stderr and not stdout: + status = "error" + elif stderr: + status = "warning" + + return { + "skill_name": skill.name, + "script_path": script_path, + "stdout": stdout, + "stderr": stderr, + "status": status, + } + except SystemExit as e: + if e.code in (None, 0): + return { + "skill_name": skill.name, + "script_path": script_path, + "stdout": "", + "stderr": "", + "status": "success", + } + return { + "error": ( + f"Failed to execute script '{script_path}':" + f" exited with code {e.code}" + ), + "error_code": "EXECUTION_ERROR", + } + except Exception as e: # pylint: disable=broad-exception-caught + logger.exception( + "Error executing script '%s' from skill '%s'", + script_path, + skill.name, + ) + short_msg = str(e) + if len(short_msg) > 200: + short_msg = short_msg[:200] + "..." + return { + "error": ( + "Failed to execute script" + f" '{script_path}':\n{type(e).__name__}:" + f" {short_msg}" + ), + "error_code": "EXECUTION_ERROR", + } + + def _build_wrapper_code( + self, + skill: models.Skill, + script_path: str, + script_args: dict[str, Any], + ) -> str | None: + """Builds a self-extracting Python script.""" + ext = "" + if "." 
in script_path: + ext = script_path.rsplit(".", 1)[-1].lower() + + if not script_path.startswith("scripts/"): + script_path = f"scripts/{script_path}" + + files_dict = {} + for ref_name in skill.resources.list_references(): + content = skill.resources.get_reference(ref_name) + if content is not None: + files_dict[f"references/{ref_name}"] = content + + for asset_name in skill.resources.list_assets(): + content = skill.resources.get_asset(asset_name) + if content is not None: + files_dict[f"assets/{asset_name}"] = content + + for scr_name in skill.resources.list_scripts(): + scr = skill.resources.get_script(scr_name) + if scr is not None and scr.src is not None: + files_dict[f"scripts/{scr_name}"] = scr.src + + total_size = sum( + len(v) if isinstance(v, (str, bytes)) else 0 + for v in files_dict.values() + ) + if total_size > _MAX_SKILL_PAYLOAD_BYTES: + logger.warning( + "Skill '%s' resources total %d bytes, exceeding" + " the recommended limit of %d bytes.", + skill.name, + total_size, + _MAX_SKILL_PAYLOAD_BYTES, + ) + + # Build the boilerplate extract string + code_lines = [ + "import os", + "import tempfile", + "import sys", + "import json as _json", + "import subprocess", + "import runpy", + f"_files = {files_dict!r}", + "def _materialize_and_run():", + " _orig_cwd = os.getcwd()", + " with tempfile.TemporaryDirectory() as td:", + " for rel_path, content in _files.items():", + " full_path = os.path.join(td, rel_path)", + " os.makedirs(os.path.dirname(full_path), exist_ok=True)", + " mode = 'wb' if isinstance(content, bytes) else 'w'", + " with open(full_path, mode) as f:", + " f.write(content)", + " os.chdir(td)", + " try:", + ] + + if ext == "py": + argv_list = [script_path] + for k, v in script_args.items(): + argv_list.extend([f"--{k}", str(v)]) + code_lines.extend([ + f" sys.argv = {argv_list!r}", + " try:", + f" runpy.run_path({script_path!r}, run_name='__main__')", + " except SystemExit as e:", + " if e.code is not None and e.code != 0:", + " raise e", + 
]) + elif ext in ("sh", "bash"): + arr = ["bash", script_path] + for k, v in script_args.items(): + arr.extend([f"--{k}", str(v)]) + timeout = self._script_timeout + code_lines.extend([ + " try:", + " _r = subprocess.run(", + f" {arr!r},", + " capture_output=True, text=True,", + f" timeout={timeout!r}, cwd=td,", + " )", + " print(_json.dumps({", + " '__shell_result__': True,", + " 'stdout': _r.stdout,", + " 'stderr': _r.stderr,", + " 'returncode': _r.returncode,", + " }))", + " except subprocess.TimeoutExpired as _e:", + " print(_json.dumps({", + " '__shell_result__': True,", + " 'stdout': _e.stdout or '',", + f" 'stderr': 'Timed out after {timeout}s',", + " 'returncode': -1,", + " }))", + ]) + else: + return None + + code_lines.extend([ + " finally:", + " os.chdir(_orig_cwd)", + ]) + + code_lines.append("_materialize_and_run()") + return "\n".join(code_lines) + + +@experimental(FeatureName.SKILL_TOOLSET) +class RunSkillScriptTool(BaseTool): + """Tool to execute scripts from a skill's scripts/ directory.""" + + def __init__(self, toolset: "SkillToolset"): + super().__init__( + name="run_skill_script", + description="Executes a script from a skill's scripts/ directory.", + ) + self._toolset = toolset + + def _get_declaration(self) -> types.FunctionDeclaration | None: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters_json_schema={ + "type": "object", + "properties": { + "skill_name": { + "type": "string", + "description": "The name of the skill.", + }, + "script_path": { + "type": "string", + "description": ( + "The relative path to the script (e.g.," + " 'scripts/setup.py')." + ), + }, + "args": { + "type": "object", + "description": ( + "Optional arguments to pass to the script as key-value" + " pairs." 
+ ), + }, + }, + "required": ["skill_name", "script_path"], + }, + ) + + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + skill_name = args.get("skill_name") + script_path = args.get("script_path") + script_args = args.get("args", {}) + if not isinstance(script_args, dict): + return { + "error": ( + "'args' must be a JSON object (key-value pairs)," + f" got {type(script_args).__name__}." + ), + "error_code": "INVALID_ARGS_TYPE", + } + + if not skill_name: + return { + "error": "Skill name is required.", + "error_code": "MISSING_SKILL_NAME", + } + if not script_path: + return { + "error": "Script path is required.", + "error_code": "MISSING_SCRIPT_PATH", + } + + skill = self._toolset._get_skill(skill_name) + if not skill: + return { + "error": f"Skill '{skill_name}' not found.", + "error_code": "SKILL_NOT_FOUND", + } + + script = None + if script_path.startswith("scripts/"): + script = skill.resources.get_script(script_path[len("scripts/") :]) + else: + script = skill.resources.get_script(script_path) + + if script is None: + return { + "error": f"Script '{script_path}' not found in skill '{skill_name}'.", + "error_code": "SCRIPT_NOT_FOUND", + } + + # Resolve code executor: toolset-level first, then agent fallback + code_executor = self._toolset._code_executor + if code_executor is None: + agent = tool_context._invocation_context.agent + if hasattr(agent, "code_executor"): + code_executor = agent.code_executor + if code_executor is None: + return { + "error": ( + "No code executor configured. A code executor is" + " required to run scripts." 
+ ),
+ "error_code": "NO_CODE_EXECUTOR",
+ }
+
+ script_executor = _SkillScriptCodeExecutor(
+ code_executor, self._toolset._script_timeout # pylint: disable=protected-access
+ )
+ return await script_executor.execute_script_async(
+ tool_context._invocation_context, skill, script_path, script_args # pylint: disable=protected-access
+ )
+
+
@experimental(FeatureName.SKILL_TOOLSET)
class SkillToolset(BaseToolset):
"""A toolset for managing and interacting with agent skills."""
@@ -234,7 +582,19 @@ class SkillToolset(BaseToolset):
def __init__(
self,
skills: list[models.Skill],
+ *,
+ code_executor: Optional[BaseCodeExecutor] = None,
+ script_timeout: int = _DEFAULT_SCRIPT_TIMEOUT,
):
+ """Initializes the SkillToolset.
+
+ Args:
+ skills: List of skills to register.
+ code_executor: Optional code executor for script execution.
+ script_timeout: Timeout in seconds for shell script execution via
+ subprocess.run. Defaults to 300 seconds. Does not apply to Python
+ scripts executed via runpy.run_path().
+ """
super().__init__()

# Check for duplicate skill names
@@ -245,11 +605,17 @@ def __init__(
seen.add(skill.name)

self._skills = {skill.name: skill for skill in skills}
+ self._code_executor = code_executor
+ self._script_timeout = script_timeout
+
+ # Initialize core skill tools
self._tools = [
ListSkillsTool(self),
LoadSkillTool(self),
LoadSkillResourceTool(self),
]
+ # Always add RunSkillScriptTool, relies on invocation_context fallback if _code_executor is None
+ self._tools.append(RunSkillScriptTool(self))

async def get_tools(
self, readonly_context: ReadonlyContext | None = None
diff --git a/tests/unittests/tools/test_skill_toolset.py b/tests/unittests/tools/test_skill_toolset.py
index 066eedfb67..6532332435 100644
--- a/tests/unittests/tools/test_skill_toolset.py
+++ b/tests/unittests/tools/test_skill_toolset.py
@@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# pylint: disable=redefined-outer-name,g-import-not-at-top,protected-access + + from unittest import mock +from google.adk.code_executors.base_code_executor import BaseCodeExecutor +from google.adk.code_executors.code_execution_utils import CodeExecutionResult from google.adk.models import llm_request as llm_request_model from google.adk.skills import models from google.adk.tools import skill_toolset @@ -27,6 +32,7 @@ def mock_skill1_frontmatter(): frontmatter = mock.create_autospec(models.Frontmatter, instance=True) frontmatter.name = "skill1" frontmatter.description = "Skill 1 description" + frontmatter.allowed_tools = ["test_tool"] frontmatter.model_dump.return_value = { "name": "skill1", "description": "Skill 1 description", @@ -43,7 +49,14 @@ def mock_skill1(mock_skill1_frontmatter): skill.instructions = "instructions for skill1" skill.frontmatter = mock_skill1_frontmatter skill.resources = mock.MagicMock( - spec=["get_reference", "get_asset", "get_script"] + spec=[ + "get_reference", + "get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] ) def get_ref(name): @@ -59,11 +72,22 @@ def get_asset(name): def get_script(name): if name == "setup.sh": return models.Script(src="echo setup") + if name == "run.py": + return models.Script(src="print('hello')") + if name == "build.rb": + return models.Script(src="puts 'hello'") return None skill.resources.get_reference.side_effect = get_ref skill.resources.get_asset.side_effect = get_asset skill.resources.get_script.side_effect = get_script + skill.resources.list_references.return_value = ["ref1.md"] + skill.resources.list_assets.return_value = ["asset1.txt"] + skill.resources.list_scripts.return_value = [ + "setup.sh", + "run.py", + "build.rb", + ] return skill @@ -73,6 +97,7 @@ def mock_skill2_frontmatter(): frontmatter = mock.create_autospec(models.Frontmatter, instance=True) frontmatter.name = "skill2" frontmatter.description = "Skill 2 description" + frontmatter.allowed_tools = [] 
frontmatter.model_dump.return_value = { "name": "skill2", "description": "Skill 2 description", @@ -89,7 +114,14 @@ def mock_skill2(mock_skill2_frontmatter): skill.instructions = "instructions for skill2" skill.frontmatter = mock_skill2_frontmatter skill.resources = mock.MagicMock( - spec=["get_reference", "get_asset", "get_script"] + spec=[ + "get_reference", + "get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] ) def get_ref(name): @@ -104,6 +136,9 @@ def get_asset(name): skill.resources.get_reference.side_effect = get_ref skill.resources.get_asset.side_effect = get_asset + skill.resources.list_references.return_value = ["ref2.md"] + skill.resources.list_assets.return_value = ["asset2.txt"] + skill.resources.list_scripts.return_value = [] return skill @@ -132,13 +167,13 @@ def test_list_skills(mock_skill1, mock_skill2): async def test_get_tools(mock_skill1, mock_skill2): toolset = skill_toolset.SkillToolset([mock_skill1, mock_skill2]) tools = await toolset.get_tools() - assert len(tools) == 3 + assert len(tools) == 4 assert isinstance(tools[0], skill_toolset.ListSkillsTool) assert isinstance(tools[1], skill_toolset.LoadSkillTool) assert isinstance(tools[2], skill_toolset.LoadSkillResourceTool) + assert isinstance(tools[3], skill_toolset.RunSkillScriptTool) -@pytest.mark.asyncio @pytest.mark.asyncio async def test_list_skills_tool( mock_skill1, mock_skill2, tool_context_instance @@ -308,3 +343,854 @@ async def test_scripts_resource_not_found(mock_skill1, tool_context_instance): tool_context=tool_context_instance, ) assert result["error_code"] == "RESOURCE_NOT_FOUND" + + +# RunSkillScriptTool tests + + +def _make_tool_context_with_agent(agent=None): + """Creates a mock ToolContext with _invocation_context.agent.""" + ctx = mock.MagicMock(spec=tool_context.ToolContext) + ctx._invocation_context = mock.MagicMock() + ctx._invocation_context.agent = agent or mock.MagicMock() + return ctx + + +def _make_mock_executor(stdout="", 
stderr=""): + """Creates a mock code executor that returns the given output.""" + executor = mock.create_autospec(BaseCodeExecutor, instance=True) + executor.execute_code.return_value = CodeExecutionResult( + stdout=stdout, stderr=stderr + ) + return executor + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "args, expected_error_code", + [ + ( + {"script_path": "setup.sh"}, + "MISSING_SKILL_NAME", + ), + ( + {"skill_name": "skill1"}, + "MISSING_SCRIPT_PATH", + ), + ( + {"skill_name": "", "script_path": "setup.sh"}, + "MISSING_SKILL_NAME", + ), + ( + {"skill_name": "skill1", "script_path": ""}, + "MISSING_SCRIPT_PATH", + ), + ], +) +async def test_execute_script_missing_params( + mock_skill1, args, expected_error_code +): + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async(args=args, tool_context=ctx) + assert result["error_code"] == expected_error_code + + +@pytest.mark.asyncio +async def test_execute_script_skill_not_found(mock_skill1): + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "nonexistent", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["error_code"] == "SKILL_NOT_FOUND" + + +@pytest.mark.asyncio +async def test_execute_script_script_not_found(mock_skill1): + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "nonexistent.py"}, + tool_context=ctx, + ) + assert result["error_code"] == "SCRIPT_NOT_FOUND" + + 
+@pytest.mark.asyncio +async def test_execute_script_no_code_executor(mock_skill1): + toolset = skill_toolset.SkillToolset([mock_skill1]) + tool = skill_toolset.RunSkillScriptTool(toolset) + # Agent without code_executor attribute + agent = mock.MagicMock(spec=[]) + ctx = _make_tool_context_with_agent(agent=agent) + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["error_code"] == "NO_CODE_EXECUTOR" + + +@pytest.mark.asyncio +async def test_execute_script_agent_code_executor_none(mock_skill1): + """Agent has code_executor attr but it's None.""" + toolset = skill_toolset.SkillToolset([mock_skill1]) + tool = skill_toolset.RunSkillScriptTool(toolset) + agent = mock.MagicMock() + agent.code_executor = None + ctx = _make_tool_context_with_agent(agent=agent) + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["error_code"] == "NO_CODE_EXECUTOR" + + +@pytest.mark.asyncio +async def test_execute_script_unsupported_type(mock_skill1): + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "build.rb"}, + tool_context=ctx, + ) + assert result["error_code"] == "UNSUPPORTED_SCRIPT_TYPE" + + +@pytest.mark.asyncio +async def test_execute_script_python_success(mock_skill1): + executor = _make_mock_executor(stdout="hello\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "hello\n" + assert 
result["stderr"] == "" + assert result["skill_name"] == "skill1" + assert result["script_path"] == "run.py" + + # Verify the code passed to executor runs the python scripts + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "_materialize_and_run()" in code_input.code + assert "import runpy" in code_input.code + assert "sys.argv = ['scripts/run.py']" in code_input.code + assert ( + "runpy.run_path('scripts/run.py', run_name='__main__')" in code_input.code + ) + + +@pytest.mark.asyncio +async def test_execute_script_shell_success(mock_skill1): + executor = _make_mock_executor(stdout="setup\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "setup\n" + + # Verify the code wraps in subprocess.run with JSON envelope + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "subprocess.run" in code_input.code + assert "bash" in code_input.code + assert "__shell_result__" in code_input.code + + +@pytest.mark.asyncio +async def test_execute_script_with_input_args_python(mock_skill1): + executor = _make_mock_executor(stdout="done\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "skill1", + "script_path": "run.py", + "args": {"verbose": True, "count": "3"}, + }, + tool_context=ctx, + ) + assert result["status"] == "success" + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert ( + "['scripts/run.py', '--verbose', 'True', '--count', '3']" + in code_input.code + ) + + +@pytest.mark.asyncio +async def 
test_execute_script_with_input_args_shell(mock_skill1): + executor = _make_mock_executor(stdout="done\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "skill1", + "script_path": "setup.sh", + "args": {"force": True}, + }, + tool_context=ctx, + ) + assert result["status"] == "success" + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "['bash', 'scripts/setup.sh', '--force', 'True']" in code_input.code + + +@pytest.mark.asyncio +async def test_execute_script_scripts_prefix_stripping(mock_skill1): + executor = _make_mock_executor(stdout="setup\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "skill1", + "script_path": "scripts/setup.sh", + }, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["script_path"] == "scripts/setup.sh" + + +@pytest.mark.asyncio +async def test_execute_script_toolset_executor_priority(mock_skill1): + """Toolset-level executor takes priority over agent's.""" + toolset_executor = _make_mock_executor(stdout="from toolset\n") + agent_executor = _make_mock_executor(stdout="from agent\n") + toolset = skill_toolset.SkillToolset( + [mock_skill1], code_executor=toolset_executor + ) + tool = skill_toolset.RunSkillScriptTool(toolset) + agent = mock.MagicMock() + agent.code_executor = agent_executor + ctx = _make_tool_context_with_agent(agent=agent) + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["stdout"] == "from toolset\n" + toolset_executor.execute_code.assert_called_once() + agent_executor.execute_code.assert_not_called() + + 
+@pytest.mark.asyncio +async def test_execute_script_agent_executor_fallback(mock_skill1): + """Falls back to agent's code executor when toolset has none.""" + agent_executor = _make_mock_executor(stdout="from agent\n") + toolset = skill_toolset.SkillToolset([mock_skill1]) + tool = skill_toolset.RunSkillScriptTool(toolset) + agent = mock.MagicMock() + agent.code_executor = agent_executor + ctx = _make_tool_context_with_agent(agent=agent) + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["stdout"] == "from agent\n" + agent_executor.execute_code.assert_called_once() + + +@pytest.mark.asyncio +async def test_execute_script_execution_error(mock_skill1): + executor = _make_mock_executor() + executor.execute_code.side_effect = RuntimeError("boom") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["error_code"] == "EXECUTION_ERROR" + assert "boom" in result["error"] + assert result["error"].startswith("Failed to execute script 'run.py':") + + +@pytest.mark.asyncio +async def test_execute_script_stderr_only_sets_error_status(mock_skill1): + """stderr with no stdout should report error status.""" + executor = _make_mock_executor(stdout="", stderr="fatal error\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "error" + assert result["stderr"] == "fatal error\n" + + +@pytest.mark.asyncio +async def test_execute_script_stderr_with_stdout_sets_warning(mock_skill1): + """stderr alongside 
stdout should report warning status.""" + executor = _make_mock_executor(stdout="output\n", stderr="deprecation\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "warning" + assert result["stdout"] == "output\n" + assert result["stderr"] == "deprecation\n" + + +@pytest.mark.asyncio +async def test_execute_script_execution_error_truncated(mock_skill1): + """Long exception messages are truncated to avoid wasting LLM tokens.""" + executor = _make_mock_executor() + executor.execute_code.side_effect = RuntimeError("x" * 300) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["error_code"] == "EXECUTION_ERROR" + # 200 chars of the message + "..." 
suffix + the prefix + assert result["error"].endswith("...") + assert len(result["error"]) < 300 + + +@pytest.mark.asyncio +async def test_execute_script_system_exit_caught(mock_skill1): + """sys.exit() in a script should not terminate the process.""" + executor = _make_mock_executor() + executor.execute_code.side_effect = SystemExit(1) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["error_code"] == "EXECUTION_ERROR" + assert "exited with code 1" in result["error"] + + +@pytest.mark.asyncio +async def test_execute_script_system_exit_zero_is_success(mock_skill1): + """sys.exit(0) is a normal termination and should report success.""" + executor = _make_mock_executor() + executor.execute_code.side_effect = SystemExit(0) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "success" + + +@pytest.mark.asyncio +async def test_execute_script_system_exit_none_is_success(mock_skill1): + """sys.exit() with no arg (None) should report success.""" + executor = _make_mock_executor() + executor.execute_code.side_effect = SystemExit(None) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + assert result["status"] == "success" + + +@pytest.mark.asyncio +async def test_execute_script_shell_includes_timeout(mock_skill1): + """Shell wrapper includes 
timeout in subprocess.run.""" + executor = _make_mock_executor(stdout="ok\n") + toolset = skill_toolset.SkillToolset( + [mock_skill1], code_executor=executor, script_timeout=60 + ) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "success" + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "timeout=60" in code_input.code + + +@pytest.mark.asyncio +async def test_execute_script_extensionless_unsupported(mock_skill1): + """Files without extensions should return UNSUPPORTED_SCRIPT_TYPE.""" + # Add a script with no extension to the mock + original_side_effect = mock_skill1.resources.get_script.side_effect + + def get_script_extended(name): + if name == "noext": + return models.Script(src="print('hi')") + return original_side_effect(name) + + mock_skill1.resources.get_script.side_effect = get_script_extended + + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "noext"}, + tool_context=ctx, + ) + assert result["error_code"] == "UNSUPPORTED_SCRIPT_TYPE" + + +# ── Integration tests using real UnsafeLocalCodeExecutor ── + + +def _make_skill_with_script(skill_name, script_name, script): + """Creates a minimal mock Skill with a single script.""" + skill = mock.create_autospec(models.Skill, instance=True) + skill.name = skill_name + skill.description = f"Test skill {skill_name}" + skill.instructions = "test instructions" + fm = mock.create_autospec(models.Frontmatter, instance=True) + fm.name = skill_name + fm.description = f"Test skill {skill_name}" + skill.frontmatter = fm + skill.resources = mock.MagicMock( + spec=[ + 
"get_reference", + "get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] + ) + + def get_script(name): + if name == script_name: + return script + return None + + skill.resources.get_script.side_effect = get_script + skill.resources.get_reference.return_value = None + skill.resources.get_asset.return_value = None + skill.resources.list_references.return_value = [] + skill.resources.list_assets.return_value = [] + skill.resources.list_scripts.return_value = [script_name] + return skill + + +def _make_real_executor_toolset(skills, **kwargs): + """Creates a SkillToolset with a real UnsafeLocalCodeExecutor.""" + from google.adk.code_executors.unsafe_local_code_executor import UnsafeLocalCodeExecutor + + executor = UnsafeLocalCodeExecutor() + return skill_toolset.SkillToolset(skills, code_executor=executor, **kwargs) + + +@pytest.mark.asyncio +async def test_integration_python_stdout(): + """Real executor: Python script stdout is captured.""" + script = models.Script(src="print('hello world')") + skill = _make_skill_with_script("test_skill", "hello.py", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": "hello.py", + }, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "hello world\n" + assert result["stderr"] == "" + + +@pytest.mark.asyncio +async def test_integration_python_sys_exit_zero(): + """Real executor: sys.exit(0) is treated as success.""" + script = models.Script(src="import sys; sys.exit(0)") + skill = _make_skill_with_script("test_skill", "exit_zero.py", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": 
"exit_zero.py", + }, + tool_context=ctx, + ) + assert result["status"] == "success" + + +@pytest.mark.asyncio +async def test_integration_shell_stdout_and_stderr(): + """Real executor: shell script preserves both stdout and stderr.""" + script = models.Script(src="echo output; echo warning >&2") + skill = _make_skill_with_script("test_skill", "both.sh", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": "both.sh", + }, + tool_context=ctx, + ) + assert result["status"] == "warning" + assert "output" in result["stdout"] + assert "warning" in result["stderr"] + + +@pytest.mark.asyncio +async def test_integration_shell_stderr_only(): + """Real executor: shell script with only stderr reports error.""" + script = models.Script(src="echo failure >&2") + skill = _make_skill_with_script("test_skill", "err.sh", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": "err.sh", + }, + tool_context=ctx, + ) + assert result["status"] == "error" + assert "failure" in result["stderr"] + + +# ── Shell JSON envelope parsing (unit tests with mock executor) ── + + +@pytest.mark.asyncio +async def test_shell_json_envelope_parsed(mock_skill1): + """Shell JSON envelope is correctly unpacked by run_async.""" + import json + + envelope = json.dumps({ + "__shell_result__": True, + "stdout": "hello from shell\n", + "stderr": "", + "returncode": 0, + }) + executor = _make_mock_executor(stdout=envelope) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", 
"script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "hello from shell\n" + assert result["stderr"] == "" + + +@pytest.mark.asyncio +async def test_shell_json_envelope_nonzero_returncode(mock_skill1): + """Non-zero returncode in shell envelope sets stderr.""" + import json + + envelope = json.dumps({ + "__shell_result__": True, + "stdout": "", + "stderr": "", + "returncode": 2, + }) + executor = _make_mock_executor(stdout=envelope) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "error" + assert "Exit code 2" in result["stderr"] + + +@pytest.mark.asyncio +async def test_shell_json_envelope_with_stderr(mock_skill1): + """Shell envelope with both stdout and stderr reports warning.""" + import json + + envelope = json.dumps({ + "__shell_result__": True, + "stdout": "data\n", + "stderr": "deprecation warning\n", + "returncode": 0, + }) + executor = _make_mock_executor(stdout=envelope) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "warning" + assert result["stdout"] == "data\n" + assert result["stderr"] == "deprecation warning\n" + + +@pytest.mark.asyncio +async def test_shell_json_envelope_timeout(mock_skill1): + """Shell envelope from TimeoutExpired reports error status.""" + import json + + envelope = json.dumps({ + "__shell_result__": True, + "stdout": "partial output\n", + "stderr": "Timed out after 300s", + "returncode": -1, + }) + executor = 
_make_mock_executor(stdout=envelope) + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "error" + assert result["stdout"] == "partial output\n" + assert "Timed out" in result["stderr"] + + +@pytest.mark.asyncio +async def test_shell_non_json_stdout_passthrough(mock_skill1): + """Non-JSON shell stdout is passed through without parsing.""" + executor = _make_mock_executor(stdout="plain text output\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={"skill_name": "skill1", "script_path": "setup.sh"}, + tool_context=ctx, + ) + assert result["status"] == "success" + assert result["stdout"] == "plain text output\n" + + +# ── input_files packaging ── + + +@pytest.mark.asyncio +async def test_execute_script_input_files_packaged(mock_skill1): + """Verify references, assets, and scripts are packaged inside the wrapper code.""" + executor = _make_mock_executor(stdout="ok\n") + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + await tool.run_async( + args={"skill_name": "skill1", "script_path": "run.py"}, + tool_context=ctx, + ) + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + + # input_files is no longer populated; it's serialized inside the script + assert code_input.input_files is None or len(code_input.input_files) == 0 + + # Ensure the extracted literal contains our fake files + assert "references/ref1.md" in code_input.code + assert "assets/asset1.txt" in code_input.code + assert "scripts/setup.sh" in 
code_input.code + assert "scripts/run.py" in code_input.code + assert "scripts/build.rb" in code_input.code + + # Verify content mappings exist in the string + assert "'references/ref1.md': 'ref content 1'" in code_input.code + assert "'assets/asset1.txt': 'asset content 1'" in code_input.code + + +# ── Integration: shell non-zero exit ── + + +@pytest.mark.asyncio +async def test_integration_shell_nonzero_exit(): + """Real executor: shell script with non-zero exit via JSON envelope.""" + script = models.Script(src="exit 42") + skill = _make_skill_with_script("test_skill", "fail.sh", script) + toolset = _make_real_executor_toolset([skill]) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "test_skill", + "script_path": "fail.sh", + }, + tool_context=ctx, + ) + assert result["status"] == "error" + assert "42" in result["stderr"] + + +# ── Finding 1: system instruction references correct tool name ── + + +def test_system_instruction_references_run_skill_script(): + """System instruction must reference the actual tool name.""" + assert "run_skill_script" in skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + assert ( + "execute_skill_script" + not in skill_toolset.DEFAULT_SKILL_SYSTEM_INSTRUCTION + ) + + +# ── Finding 2: empty files are mounted (not silently dropped) ── + + +@pytest.mark.asyncio +async def test_execute_script_empty_files_mounted(): + """Verify empty files are included in wrapper code, not dropped.""" + skill = mock.create_autospec(models.Skill, instance=True) + skill.name = "skill_empty" + skill.resources = mock.MagicMock( + spec=[ + "get_reference", + "get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] + ) + skill.resources.get_reference.side_effect = ( + lambda n: "" if n == "empty.md" else None + ) + skill.resources.get_asset.side_effect = ( + lambda n: "" if n == "empty.cfg" else None + ) + 
skill.resources.get_script.side_effect = ( + lambda n: models.Script(src="") if n == "run.py" else None + ) + skill.resources.list_references.return_value = ["empty.md"] + skill.resources.list_assets.return_value = ["empty.cfg"] + skill.resources.list_scripts.return_value = ["run.py"] + + executor = _make_mock_executor(stdout="ok\n") + toolset = skill_toolset.SkillToolset([skill], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + await tool.run_async( + args={"skill_name": "skill_empty", "script_path": "run.py"}, + tool_context=ctx, + ) + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + assert "'references/empty.md': ''" in code_input.code + assert "'assets/empty.cfg': ''" in code_input.code + assert "'scripts/run.py': ''" in code_input.code + + +# ── Finding 3: invalid args type returns clear error ── + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "bad_args", + [ + "not a dict", + ["a", "list"], + 42, + True, + ], +) +async def test_execute_script_invalid_args_type(mock_skill1, bad_args): + """Non-dict args should return INVALID_ARGS_TYPE, not crash.""" + executor = _make_mock_executor() + toolset = skill_toolset.SkillToolset([mock_skill1], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + result = await tool.run_async( + args={ + "skill_name": "skill1", + "script_path": "run.py", + "args": bad_args, + }, + tool_context=ctx, + ) + assert result["error_code"] == "INVALID_ARGS_TYPE" + executor.execute_code.assert_not_called() + + +# ── Finding 4: binary file content is handled in wrapper ── + + +@pytest.mark.asyncio +async def test_execute_script_binary_content_packaged(): + """Verify binary asset content uses 'wb' mode in wrapper code.""" + skill = mock.create_autospec(models.Skill, instance=True) + skill.name = "skill_bin" + skill.resources = mock.MagicMock( + spec=[ + "get_reference", + 
"get_asset", + "get_script", + "list_references", + "list_assets", + "list_scripts", + ] + ) + skill.resources.get_reference.side_effect = ( + lambda n: b"\x00\x01\x02" if n == "data.bin" else None + ) + skill.resources.get_asset.return_value = None + skill.resources.get_script.side_effect = lambda n: ( + models.Script(src="print('ok')") if n == "run.py" else None + ) + skill.resources.list_references.return_value = ["data.bin"] + skill.resources.list_assets.return_value = [] + skill.resources.list_scripts.return_value = ["run.py"] + + executor = _make_mock_executor(stdout="ok\n") + toolset = skill_toolset.SkillToolset([skill], code_executor=executor) + tool = skill_toolset.RunSkillScriptTool(toolset) + ctx = _make_tool_context_with_agent() + await tool.run_async( + args={"skill_name": "skill_bin", "script_path": "run.py"}, + tool_context=ctx, + ) + + call_args = executor.execute_code.call_args + code_input = call_args[0][1] + # Binary content should appear as bytes literal + assert "b'\\x00\\x01\\x02'" in code_input.code + # Wrapper code handles binary with 'wb' mode + assert "'wb' if isinstance(content, bytes)" in code_input.code