# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""[Preview] Live API client."""

import asyncio
import base64
import contextlib
import json
import logging
import typing
from typing import Any, AsyncIterator, Optional, Sequence, Union, get_args
import warnings

import google.auth
import pydantic
from websockets import ConnectionClosed

from . import _api_module
from . import _common
from . import _live_converters as live_converters
from . import _mcp_utils
from . import _transformers as t
from . import errors
from . import types
from ._api_client import BaseApiClient
from ._common import get_value_by_path as getv
from ._common import set_value_by_path as setv
from .live_music import AsyncLiveMusic
from .models import _Content_to_mldev


try:
  from websockets.asyncio.client import ClientConnection
  from websockets.asyncio.client import connect as ws_connect
except ModuleNotFoundError:
  # This try/except is for TAP, mypy complains about it which is why we have the type: ignore
  from websockets.client import ClientConnection  # type: ignore
  from websockets.client import connect as ws_connect  # type: ignore

try:
  from google.auth.transport import requests
except ImportError:
  requests = None  # type: ignore[assignment]

if typing.TYPE_CHECKING:
  from mcp import ClientSession as McpClientSession
  from mcp.types import Tool as McpTool
  from ._adapters import McpToGenAiToolAdapter
  from ._mcp_utils import mcp_to_gemini_tool
else:
  McpClientSession: typing.Type = Any
  McpTool: typing.Type = Any
  McpToGenAiToolAdapter: typing.Type = Any
  try:
    from mcp import ClientSession as McpClientSession
    from mcp.types import Tool as McpTool
    from ._adapters import McpToGenAiToolAdapter
    from ._mcp_utils import mcp_to_gemini_tool
  except ImportError:
    McpClientSession = None
    McpTool = None
    McpToGenAiToolAdapter = None
    mcp_to_gemini_tool = None

logger = logging.getLogger('google_genai.live')

_FUNCTION_RESPONSE_REQUIRES_ID = (
    'FunctionResponse request must have an `id` field from the'
    ' response of a ToolCall.FunctionalCalls in Google AI.'
)


def _unsupported_input_error(value: Any) -> ValueError:
  """Builds the error raised when `send` input cannot be interpreted."""
  return ValueError(
      f'Unsupported input type "{type(value)}" or input content "{value}"'
  )


def _function_response_to_dict(
    function_response: types.FunctionResponse,
) -> types.FunctionResponseDict:
  """Converts a `FunctionResponse` to a dict with name, response and id only."""
  dumped = function_response.model_dump(exclude_none=True, mode='json')
  result = types.FunctionResponseDict(
      name=dumped.get('name'),
      response=dumped.get('response'),
  )
  # The id is only copied over when it is set (and truthy).
  if dumped.get('id'):
    result['id'] = dumped.get('id')
  return result


async def _recv_json(ws: ClientConnection) -> Any:
  """Receives one websocket message and decodes it as JSON.

  Args:
    ws: The websocket connection to read from.

  Returns:
    The decoded JSON payload, or an empty dict when the message is empty.

  Raises:
    ValueError: If the message is not valid JSON.
  """
  try:
    # websockets 14.0+
    raw_response = await ws.recv(decode=False)
  except TypeError:
    # Older websockets versions do not accept the `decode` argument.
    raw_response = await ws.recv()  # type: ignore[assignment]
  if not raw_response:
    return {}
  try:
    return json.loads(raw_response)
  except json.decoder.JSONDecodeError as e:
    raise ValueError(f'Failed to parse response: {raw_response!r}') from e


class AsyncSession:
  """[Preview] AsyncSession."""

  def __init__(
      self,
      api_client: BaseApiClient,
      websocket: ClientConnection,
      session_id: Optional[str] = None,
  ):
    self._api_client = api_client
    self._ws = websocket
    self.session_id = session_id

  async def send(
      self,
      *,
      input: Optional[
          Union[
              types.ContentListUnion,
              types.ContentListUnionDict,
              types.LiveClientContentOrDict,
              types.LiveClientRealtimeInputOrDict,
              types.LiveClientToolResponseOrDict,
              types.FunctionResponseOrDict,
              Sequence[types.FunctionResponseOrDict],
          ]
      ] = None,
      end_of_turn: Optional[bool] = False,
  ) -> None:
    """[Deprecated] Send input to the model.

    > **Warning**: This method is deprecated and will be removed in a future
    version (not before Q3 2025). Please use one of the more specific methods:
    `send_client_content`, `send_realtime_input`, or `send_tool_response`
    instead.

    The method will send the input request to the server.

    Args:
      input: The input request to the model.
      end_of_turn: Whether the input is the last message in a turn.

    Raises:
      ValueError: If the input cannot be converted to a client message.

    Example usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)

      async with client.aio.live.connect(model='...') as session:
        await session.send(input='Hello world!', end_of_turn=True)
        async for message in session.receive():
          print(message)
    """
    warnings.warn(
        'The `session.send` method is deprecated and will be removed in a '
        'future version (not before Q3 2025).\n'
        'Please use one of the more specific methods: `send_client_content`, '
        '`send_realtime_input`, or `send_tool_response` instead.',
        DeprecationWarning,
        stacklevel=2,
    )
    client_message = self._parse_client_message(input, end_of_turn)
    await self._ws.send(json.dumps(client_message))

  async def send_client_content(
      self,
      *,
      turns: Optional[
          Union[
              types.Content,
              types.ContentDict,
              list[Union[types.Content, types.ContentDict]],
          ]
      ] = None,
      turn_complete: bool = True,
  ) -> None:
    """Send non-realtime, turn based content to the model.

    There are two ways to send messages to the live API:
    `send_client_content` and `send_realtime_input`.

    `send_client_content` messages are added to the model context **in order**.
    Having a conversation using `send_client_content` messages is roughly
    equivalent to using the `Chat.send_message_stream` method, except that the
    state of the `chat` history is stored on the API server.

    Because of `send_client_content`'s order guarantee, the model cannot
    respond as quickly to `send_client_content` messages as to
    `send_realtime_input` messages. This makes the biggest difference when
    sending objects that have significant preprocessing time (typically
    images).

    The `send_client_content` message sends a list of `Content` objects,
    which has more options than the `media:Blob` sent by `send_realtime_input`.

    The main use-cases for `send_client_content` over `send_realtime_input` are:

    - Prefilling a conversation context (including sending anything that can't
      be represented as a realtime message), before starting a realtime
      conversation.
    - Conducting a non-realtime conversation, similar to `client.chat`, using
      the live api.

    Caution: Interleaving `send_client_content` and `send_realtime_input`
      in the same conversation is not recommended and can lead to unexpected
      results.

    Args:
      turns: A `Content` object or list of `Content` objects (or equivalent
        dicts).
      turn_complete: if true (the default) the model will reply immediately. If
        false, the model will wait for you to send additional client_content,
        and will not return until you send `turn_complete=True`.

    Example:

    .. code-block:: python

      from google import genai
      from google.genai import types
      import os

      if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
        MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
      else:
        MODEL_NAME = 'gemini-live-2.5-flash-preview'

      client = genai.Client()
      async with client.aio.live.connect(
          model=MODEL_NAME,
          config={"response_modalities": ["TEXT"]}
      ) as session:
        await session.send_client_content(
            turns=types.Content(
                role='user',
                parts=[types.Part(text="Hello world!")]))
        async for msg in session.receive():
          if msg.text:
            print(msg.text)
    """
    client_content = t.t_client_content(turns, turn_complete).model_dump(
        mode='json', exclude_none=True
    )

    if self._api_client.vertexai:
      client_content_dict = _common.convert_to_dict(
          client_content, convert_keys=True
      )
    else:
      client_content_dict = live_converters._LiveClientContent_to_mldev(
          from_object=client_content
      )

    await self._ws.send(json.dumps({'client_content': client_content_dict}))

  async def send_realtime_input(
      self,
      *,
      media: Optional[types.BlobImageUnionDict] = None,
      audio: Optional[types.BlobOrDict] = None,
      audio_stream_end: Optional[bool] = None,
      video: Optional[types.BlobImageUnionDict] = None,
      text: Optional[str] = None,
      activity_start: Optional[types.ActivityStartOrDict] = None,
      activity_end: Optional[types.ActivityEndOrDict] = None,
  ) -> None:
    """Send realtime input to the model, only send one argument per call.

    Use `send_realtime_input` for realtime audio chunks and video
    frames(images).

    With `send_realtime_input` the api will respond to audio automatically
    based on voice activity detection (VAD).

    `send_realtime_input` is optimized for responsiveness at the expense of
    deterministic ordering. Audio and video tokens are added to the
    context when they become available.

    Exactly one of the arguments below must be set per call.

    Args:
      media: A `Blob`-like object, the realtime media to send.
      audio: A `Blob`-like object, the realtime audio to send.
      audio_stream_end: Boolean flag forwarded as `audio_stream_end`.
      video: A `Blob`-like object or image, the realtime video frame to send.
      text: The realtime text to send.
      activity_start: An `ActivityStart`-like object forwarded as
        `activity_start`.
      activity_end: An `ActivityEnd`-like object forwarded as `activity_end`.

    Raises:
      ValueError: If zero or more than one argument is set.

    Example:

    .. code-block:: python

      from pathlib import Path

      from google import genai
      from google.genai import types

      import PIL.Image

      import os

      if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
        MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
      else:
        MODEL_NAME = 'gemini-live-2.5-flash-preview'

      client = genai.Client()

      async with client.aio.live.connect(
          model=MODEL_NAME,
          config={"response_modalities": ["TEXT"]},
      ) as session:
        await session.send_realtime_input(
            media=PIL.Image.open('image.jpg'))

        audio_bytes = Path('audio.pcm').read_bytes()
        await session.send_realtime_input(
            media=types.Blob(data=audio_bytes, mime_type='audio/pcm;rate=16000'))

        async for msg in session.receive():
          if msg.text is not None:
            print(f'{msg.text}')
    """
    candidates: _common.StringDict = {
        'media': media,
        'audio': audio,
        'audio_stream_end': audio_stream_end,
        'video': video,
        'text': text,
        'activity_start': activity_start,
        'activity_end': activity_end,
    }
    kwargs: _common.StringDict = {
        key: value for key, value in candidates.items() if value is not None
    }

    if len(kwargs) != 1:
      raise ValueError(
          f'Only one argument can be set, got {len(kwargs)}:'
          f' {list(kwargs.keys())}'
      )
    realtime_input = types.LiveSendRealtimeInputParameters.model_validate(
        kwargs
    )

    if self._api_client.vertexai:
      realtime_input_dict = (
          live_converters._LiveSendRealtimeInputParameters_to_vertex(
              from_object=realtime_input
          )
      )
    else:
      realtime_input_dict = (
          live_converters._LiveSendRealtimeInputParameters_to_mldev(
              from_object=realtime_input
          )
      )
    realtime_input_dict = _common.convert_to_dict(realtime_input_dict)
    realtime_input_dict = _common.encode_unserializable_types(
        realtime_input_dict
    )
    await self._ws.send(json.dumps({'realtime_input': realtime_input_dict}))

  async def send_tool_response(
      self,
      *,
      function_responses: Union[
          types.FunctionResponseOrDict,
          Sequence[types.FunctionResponseOrDict],
      ],
  ) -> None:
    """Send a tool response to the session.

    Use `send_tool_response` to reply to `LiveServerToolCall` messages
    from the server.

    To set the available tools, use the `config.tools` argument
    when you connect to the session (`client.live.connect`).

    Args:
      function_responses: A `FunctionResponse`-like object or list of
        `FunctionResponse`-like objects.

    Raises:
      ValueError: If the client is not using Vertex AI and a function response
        has no `id`.

    Example:

    .. code-block:: python

      from google import genai
      from google.genai import types

      import os

      if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
        MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
      else:
        MODEL_NAME = 'gemini-live-2.5-flash-preview'

      client = genai.Client()

      tools = [{'function_declarations': [{'name': 'turn_on_the_lights'}]}]
      config = {
          "tools": tools,
          "response_modalities": ['TEXT']
      }

      async with client.aio.live.connect(
          model=MODEL_NAME,
          config=config
      ) as session:
        prompt = "Turn on the lights please"
        await session.send_client_content(
            turns={"parts": [{'text': prompt}]}
        )

        async for chunk in session.receive():
          if chunk.server_content:
            if chunk.text is not None:
              print(chunk.text)
          elif chunk.tool_call:
            print(chunk.tool_call)
            print('_'*80)
            function_response=types.FunctionResponse(
                    name='turn_on_the_lights',
                    response={'result': 'ok'},
                    id=chunk.tool_call.function_calls[0].id,
                )
            print(function_response)
            await session.send_tool_response(
                function_responses=function_response
            )

            print('_'*80)
    """
    tool_response = t.t_tool_response(function_responses)
    tool_response_dict = _common.convert_to_dict(
        tool_response, convert_keys=True
    )
    # Only the non-Vertex backend requires every function response to carry
    # the id of the tool call it answers.
    if not self._api_client.vertexai:
      for response in tool_response_dict.get('functionResponses', []):
        if response.get('id') is None:
          raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)

    await self._ws.send(json.dumps({'tool_response': tool_response_dict}))

  async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
    """Receive model responses from the server.

    The method will yield the model responses from the server. The returned
    responses will represent a complete model turn. When the returned message
    is a function call, the user must call `send_tool_response` with the
    function response to continue the turn.

    Yields:
      The model responses from the server. Iteration stops after the message
      whose `server_content.turn_complete` is set has been yielded.

    Example usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)

      async with client.aio.live.connect(model='...') as session:
        await session.send_client_content(
            turns={'parts': [{'text': 'Hello world!'}]}
        )
        async for message in session.receive():
          print(message)
    """
    # TODO(b/365983264) Handle intermittent issues for the user.
    while result := await self._receive():
      yield result
      if result.server_content and result.server_content.turn_complete:
        break

  async def start_stream(
      self, *, stream: AsyncIterator[bytes], mime_type: str
  ) -> AsyncIterator[types.LiveServerMessage]:
    """[Deprecated] Start a live session from a data stream.

    > **Warning**: This method is deprecated and will be removed in a future
    version (not before Q3 2025). Please use one of the more specific methods:
    `send_client_content`, `send_realtime_input`, or `send_tool_response`
    instead.

    The interaction terminates when the input stream is complete.
    This method will start two async tasks. One task will be used to send the
    input stream to the model and the other task will be used to receive the
    responses from the model.

    Args:
      stream: An async iterator that yields the input data (bytes) to send to
        the model.
      mime_type: The MIME type of the data in the stream.

    Yields:
      The audio bytes received from the model and server response messages.

    Example usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)
      config = {'response_modalities': ['AUDIO']}

      async def audio_stream():
        stream = read_audio()
        for data in stream:
          yield data

      async with client.aio.live.connect(model='...', config=config) as session:
        async for audio in session.start_stream(
            stream=audio_stream(), mime_type='audio/pcm'
        ):
          play_audio_chunk(audio.data)
    """
    warnings.warn(
        'Setting `AsyncSession.start_stream` is deprecated, '
        'and will be removed in a future release (not before Q3 2025). '
        'Please use the `receive`, and `send_realtime_input`, methods instead.',
        DeprecationWarning,
        stacklevel=4,
    )
    stop_event = asyncio.Event()
    # Start the send loop. When stream is complete stop_event is set.
    # References to the helper tasks are kept because the event loop only
    # holds weak references to tasks; an unreferenced task can be
    # garbage-collected before it finishes.
    send_task = asyncio.create_task(
        self._send_loop(stream, mime_type, stop_event)
    )
    # A single waiter is reused across iterations instead of creating (and
    # leaking) a new `stop_event.wait()` task on every pass of the loop.
    stop_task = asyncio.create_task(stop_event.wait())
    recv_task: Optional[asyncio.Task[types.LiveServerMessage]] = None
    try:
      while not stop_event.is_set():
        try:
          recv_task = asyncio.create_task(self._receive())
          await asyncio.wait(
              [recv_task, stop_task],
              return_when=asyncio.FIRST_COMPLETED,
          )
          if recv_task.done():
            yield recv_task.result()
            # Give a chance for the send loop to process requests.
            await asyncio.sleep(10**-12)
        except ConnectionClosed:
          break
    finally:
      # Cancel whatever is still running (pending receive, the stop waiter and
      # the sender if the generator exits early) and wait for the
      # cancellations to be processed.
      pending = [
          task
          for task in (recv_task, stop_task, send_task)
          if task is not None and not task.done()
      ]
      for task in pending:
        task.cancel()
      if pending:
        await asyncio.gather(*pending, return_exceptions=True)

  async def _receive(self) -> types.LiveServerMessage:
    """Receives and parses a single server message from the websocket.

    Returns:
      The parsed `LiveServerMessage`.

    Raises:
      ValueError: If the received payload is not valid JSON.
    """
    parameter_model = types.LiveServerMessage()
    response = await _recv_json(self._ws)

    if self._api_client.vertexai:
      response_dict = live_converters._LiveServerMessage_from_vertex(response)
    else:
      response_dict = response

    return types.LiveServerMessage._from_response(
        response=response_dict, kwargs=parameter_model.model_dump()
    )

  async def _send_loop(
      self,
      data_stream: AsyncIterator[bytes],
      mime_type: str,
      stop_event: asyncio.Event,
  ) -> None:
    """Forwards every chunk of `data_stream` to the model as realtime input.

    Args:
      data_stream: Async iterator yielding the raw bytes to send.
      mime_type: The MIME type attached to every chunk.
      stop_event: Event that is set once the stream has been consumed, or the
        loop stopped because of an error or cancellation.
    """
    try:
      async for data in data_stream:
        model_input = types.LiveClientRealtimeInput(
            media_chunks=[types.Blob(data=data, mime_type=mime_type)]
        )
        await self.send(input=model_input)
        # Give a chance for the receive loop to process responses.
        await asyncio.sleep(10**-12)
    finally:
      # Always signal completion, even when the stream or the send fails, so
      # that `start_stream` does not wait on the stop event forever.
      stop_event.set()

  def _parse_client_message(
      self,
      input: Optional[
          Union[
              types.ContentListUnion,
              types.ContentListUnionDict,
              types.LiveClientContentOrDict,
              types.LiveClientRealtimeInputOrDict,
              types.LiveClientToolResponseOrDict,
              types.FunctionResponseOrDict,
              Sequence[types.FunctionResponseOrDict],
          ]
      ] = None,
      end_of_turn: Optional[bool] = False,
  ) -> types.LiveClientMessageDict:
    """Converts the loosely typed `send` input into a client message dict.

    Args:
      input: The value passed to `send`. Strings, blobs, content, realtime
        input, tool responses and function responses (objects or dicts) are
        accepted.
      end_of_turn: Used as `turn_complete` when the input is text content.

    Returns:
      A `LiveClientMessageDict` with exactly one of `client_content`,
      `realtime_input` or `tool_response` set.

    Raises:
      ValueError: If the input is not supported, or if a function response has
        no `id` while the client is not using Vertex AI.
    """
    formatted_input: Any = input

    if not input:
      logger.info('No input provided. Assume it is the end of turn.')
      return {'client_content': {'turn_complete': True}}

    # Step 1: wrap single items into a list so that the dispatch below only
    # has to deal with sequences, dicts and typed messages.
    if isinstance(input, str):
      formatted_input = [input]
    elif isinstance(input, dict) and 'data' in input:
      try:
        blob_input = types.Blob(**input)
      except pydantic.ValidationError as e:
        raise _unsupported_input_error(input) from e
      if isinstance(blob_input, types.Blob) and isinstance(
          blob_input.data, bytes
      ):
        formatted_input = [
            blob_input.model_dump(mode='json', exclude_none=True)
        ]
    elif isinstance(input, types.Blob):
      formatted_input = [input]
    elif isinstance(input, dict) and 'name' in input and 'response' in input:
      # ToolResponse.FunctionResponse
      if not (self._api_client.vertexai) and 'id' not in input:
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      formatted_input = [input]

    # Step 2: dispatch on the shape of the (possibly wrapped) input.
    if isinstance(formatted_input, Sequence) and any(
        isinstance(c, dict) and 'name' in c and 'response' in c
        for c in formatted_input
    ):
      # ToolResponse.FunctionResponse
      function_responses_input = []
      for item in formatted_input:
        if isinstance(item, dict):
          try:
            function_response_input = types.FunctionResponse(**item)
          except pydantic.ValidationError as e:
            raise _unsupported_input_error(input) from e
          if (
              function_response_input.id is None
              and not self._api_client.vertexai
          ):
            raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
          function_responses_input.append(
              _function_response_to_dict(function_response_input)
          )
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=function_responses_input
          )
      )
    elif isinstance(formatted_input, Sequence) and any(
        isinstance(c, str) for c in formatted_input
    ):
      to_object: _common.StringDict = {}
      content_input_parts: list[types.PartUnion] = [
          item
          for item in formatted_input
          if isinstance(item, get_args(types.PartUnion))
      ]
      if self._api_client.vertexai:
        contents = [
            _common.convert_to_dict(item, convert_keys=True)
            for item in t.t_contents(content_input_parts)
        ]
      else:
        contents = [
            _Content_to_mldev(item, to_object)
            for item in t.t_contents(content_input_parts)
        ]

      content_dict_list: list[types.ContentDict] = []
      for item in contents:
        try:
          content_input = types.Content(**item)
        except pydantic.ValidationError as e:
          raise _unsupported_input_error(input) from e
        content_dict_list.append(
            types.ContentDict(
                parts=content_input.model_dump(exclude_none=True, mode='json')[
                    'parts'
                ],
                role=content_input.role,
            )
        )

      client_message = types.LiveClientMessageDict(
          client_content=types.LiveClientContentDict(
              turns=content_dict_list, turn_complete=end_of_turn
          )
      )
    elif isinstance(formatted_input, Sequence) and all(
        isinstance(c, types.FunctionResponse) for c in formatted_input
    ):
      # A sequence of FunctionResponse objects. This must be tested before the
      # generic sequence branch below, which treats every remaining sequence
      # as media chunks and rejects anything else.
      if not self._api_client.vertexai and any(
          not item.id for item in formatted_input
      ):
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=[
                  _function_response_to_dict(item) for item in formatted_input
              ]
          )
      )
    elif isinstance(formatted_input, Sequence):
      if any((isinstance(b, dict) and 'data' in b) for b in formatted_input):
        # Blob dicts are sent as they are.
        pass
      elif any(isinstance(b, types.Blob) for b in formatted_input):
        formatted_input = [
            b.model_dump(exclude_none=True, mode='json')
            for b in formatted_input
        ]
      else:
        raise _unsupported_input_error(input)

      client_message = types.LiveClientMessageDict(
          realtime_input=types.LiveClientRealtimeInputDict(
              media_chunks=formatted_input
          )
      )
    elif isinstance(formatted_input, dict):
      if 'content' in formatted_input or 'turns' in formatted_input:
        # TODO(b/365983264) Add validation checks for content_update input_dict.
        if 'turns' in formatted_input:
          content_turns = formatted_input['turns']
        else:
          content_turns = formatted_input['content']
        client_message = types.LiveClientMessageDict(
            client_content=types.LiveClientContentDict(
                turns=content_turns,
                turn_complete=formatted_input.get('turn_complete'),
            )
        )
      elif 'media_chunks' in formatted_input:
        try:
          realtime_input = types.LiveClientRealtimeInput(**formatted_input)
        except pydantic.ValidationError as e:
          raise _unsupported_input_error(input) from e
        client_message = types.LiveClientMessageDict(
            realtime_input=types.LiveClientRealtimeInputDict(
                media_chunks=realtime_input.model_dump(
                    exclude_none=True, mode='json'
                )['media_chunks']
            )
        )
      elif 'function_responses' in formatted_input:
        try:
          tool_response_input = types.LiveClientToolResponse(**formatted_input)
        except pydantic.ValidationError as e:
          raise _unsupported_input_error(input) from e
        client_message = types.LiveClientMessageDict(
            tool_response=types.LiveClientToolResponseDict(
                function_responses=tool_response_input.model_dump(
                    exclude_none=True, mode='json'
                )['function_responses']
            )
        )
      else:
        raise _unsupported_input_error(input)
    elif isinstance(formatted_input, types.LiveClientRealtimeInput):
      realtime_input_dict = formatted_input.model_dump(
          exclude_none=True, mode='json'
      )
      media_chunks = realtime_input_dict.get('media_chunks')
      client_message = types.LiveClientMessageDict(
          realtime_input=types.LiveClientRealtimeInputDict(
              media_chunks=media_chunks
          )
      )
      # Guard against an empty chunk list and against chunks without `data`
      # (dropped by `exclude_none`) before looking at the first chunk.
      # NOTE(review): with `mode='json'` the dumped data is presumably already
      # a string, so this branch looks rarely taken - confirm.
      if (
          media_chunks
          and isinstance(media_chunks[0], dict)
          and isinstance(media_chunks[0].get('data'), bytes)
      ):
        formatted_media_chunks: list[types.BlobDict] = []
        for item in media_chunks:
          if isinstance(item, dict):
            try:
              blob_input = types.Blob(**item)
            except pydantic.ValidationError as e:
              raise _unsupported_input_error(input) from e
            if isinstance(blob_input, types.Blob) and isinstance(
                blob_input.data, bytes
            ):
              formatted_media_chunks.append(
                  types.BlobDict(
                      data=base64.b64decode(blob_input.data),
                      mime_type=blob_input.mime_type,
                  )
              )

        client_message['realtime_input'][
            'media_chunks'
        ] = formatted_media_chunks

    elif isinstance(formatted_input, types.LiveClientContent):
      client_content_dict = formatted_input.model_dump(
          exclude_none=True, mode='json'
      )
      client_message = types.LiveClientMessageDict(
          client_content=types.LiveClientContentDict(
              turns=client_content_dict.get('turns'),
              turn_complete=client_content_dict.get('turn_complete'),
          )
      )
    elif isinstance(formatted_input, types.LiveClientToolResponse):
      # ToolResponse.FunctionResponse
      # Only the first response is checked for an id; an empty list is passed
      # through unchanged.
      if (
          not (self._api_client.vertexai)
          and formatted_input.function_responses
          and not (formatted_input.function_responses[0].id)
      ):
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=formatted_input.model_dump(
                  exclude_none=True, mode='json'
              ).get('function_responses')
          )
      )
    elif isinstance(formatted_input, types.FunctionResponse):
      if not (self._api_client.vertexai) and not (formatted_input.id):
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=[_function_response_to_dict(formatted_input)]
          )
      )
    else:
      raise _unsupported_input_error(input)

    return client_message

  async def close(self) -> None:
    """Closes the underlying websocket connection."""
    await self._ws.close()


class AsyncLive(_api_module.BaseModule):
  """[Preview] AsyncLive."""

  def __init__(self, api_client: BaseApiClient):
    super().__init__(api_client)
    self._music = AsyncLiveMusic(api_client)

  @property
  def music(self) -> AsyncLiveMusic:
    """The `AsyncLiveMusic` module created for this client."""
    return self._music

  @contextlib.asynccontextmanager
  async def connect(
      self,
      *,
      model: str,
      config: Optional[types.LiveConnectConfigOrDict] = None,
  ) -> AsyncIterator[AsyncSession]:
    """[Preview] Connect to the live server.

    Note: the live API is currently in preview.

    Usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)
      config = {}
      async with client.aio.live.connect(model='...', config=config) as session:
        await session.send_client_content(
          turns=types.Content(
            role='user',
            parts=[types.Part(text='hello!')]
          ),
          turn_complete=True
        )
        async for message in session.receive():
          print(message)

    Args:
      model: The model to use for the live session.
      config: The configuration for the live session.

    Yields:
      An AsyncSession object.

    Raises:
      ValueError: If `config.http_options` is set, if credentials need to be
        refreshed but the google-auth `requests` transport is not installed,
        or if the setup response is not valid JSON.
    """
    # TODO(b/404946570): Support per request http options.
    if isinstance(config, dict):
      config = types.LiveConnectConfig(**config)
    if config and config.http_options:
      raise ValueError(
          'google.genai.client.aio.live.connect() does not support'
          ' http_options at request-level in LiveConnectConfig yet. Please use'
          ' the client-level http_options configuration instead.'
      )

    base_url = self._api_client._websocket_base_url()
    if isinstance(base_url, bytes):
      base_url = base_url.decode('utf-8')
    transformed_model = t.t_model(self._api_client, model)  # type: ignore

    parameter_model = await _t_live_connect_config(self._api_client, config)

    # Every branch below uses the same API version and starts from a copy of
    # the client-level headers, so both are computed once up front.
    version = self._api_client._http_options.api_version
    original_headers = self._api_client._http_options.headers
    headers = original_headers.copy() if original_headers is not None else {}

    if self._api_client.api_key and not self._api_client.vertexai:
      api_key = self._api_client.api_key
      method = 'BidiGenerateContent'
      if api_key.startswith('auth_tokens/'):
        warnings.warn(
            message=(
                "The SDK's ephemeral token support is experimental, and may"
                ' change in future versions.'
            ),
            category=errors.ExperimentalWarning,
        )
        method = 'BidiGenerateContentConstrained'
        headers['Authorization'] = f'Token {api_key}'
        if version != 'v1alpha':
          warnings.warn(
              message=(
                  "The SDK's ephemeral token support is in v1alpha only. "
                  'Please use client = genai.Client(api_key=token.name, '
                  'http_options=types.HttpOptions(api_version="v1alpha"))'
                  ' before session connection.'
              ),
              category=errors.ExperimentalWarning,
          )
      uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}'

      request_dict = _common.convert_to_dict(
          live_converters._LiveConnectParameters_to_mldev(
              api_client=self._api_client,
              from_object=types.LiveConnectParameters(
                  model=transformed_model,
                  config=parameter_model,
              ).model_dump(exclude_none=True),
          )
      )
      del request_dict['config']
      request_dict = _common.encode_unserializable_types(request_dict)

      setv(request_dict, ['setup', 'model'], transformed_model)

      request = json.dumps(request_dict)
    elif self._api_client.api_key and self._api_client.vertexai:
      # Headers already contains api key for express mode.
      uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'

      request_dict = _common.convert_to_dict(
          live_converters._LiveConnectParameters_to_vertex(
              api_client=self._api_client,
              from_object=types.LiveConnectParameters(
                  model=transformed_model,
                  config=parameter_model,
              ).model_dump(exclude_none=True),
          )
      )
      del request_dict['config']
      request_dict = _common.encode_unserializable_types(request_dict)

      setv(request_dict, ['setup', 'model'], transformed_model)

      request = json.dumps(request_dict)
    else:
      has_sufficient_auth = (
          self._api_client.project and self._api_client.location
      )
      if self._api_client.custom_base_url and not has_sufficient_auth:
        # API gateway proxy can use the auth in custom headers, not url.
        # Enable custom url if auth is not sufficient.
        uri = self._api_client.custom_base_url
        # Keep the model as is.
        transformed_model = model
        # Do not get credentials for custom url.
      else:
        uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'

        if not self._api_client._credentials:
          # Get bearer token through Application Default Credentials.
          creds, _ = google.auth.default(  # type: ignore
              scopes=['https://www.googleapis.com/auth/cloud-platform']
          )
        else:
          creds = self._api_client._credentials
        # creds.valid is False, and creds.token is None
        # Need to refresh credentials to populate those
        if not (creds.token and creds.valid):
          if requests is None:
            raise ValueError('The requests module is required to refresh google-auth credentials. Please install with `pip install google-auth[requests]`')
          auth_req = requests.Request()  # type: ignore
          creds.refresh(auth_req)  # type: ignore[no-untyped-call]
        bearer_token = creds.token

        # An Authorization header supplied by the caller takes precedence.
        if not headers.get('Authorization'):
          headers['Authorization'] = f'Bearer {bearer_token}'

      location = self._api_client.location
      project = self._api_client.project
      if transformed_model.startswith('publishers/') and project and location:
        transformed_model = (
            f'projects/{project}/locations/{location}/' + transformed_model
        )
      request_dict = _common.convert_to_dict(
          live_converters._LiveConnectParameters_to_vertex(
              api_client=self._api_client,
              from_object=types.LiveConnectParameters(
                  model=transformed_model,
                  config=parameter_model,
              ).model_dump(exclude_none=True),
          )
      )
      del request_dict['config']
      request_dict = _common.encode_unserializable_types(request_dict)

      # Default to audio responses when no response modality was requested.
      if (
          getv(
              request_dict, ['setup', 'generationConfig', 'responseModalities']
          )
          is None
      ):
        setv(
            request_dict,
            ['setup', 'generationConfig', 'responseModalities'],
            ['AUDIO'],
        )

      request = json.dumps(request_dict)

    if parameter_model.tools and _mcp_utils.has_mcp_tool_usage(
        parameter_model.tools
    ):
      if headers is None:
        headers = {}
      _mcp_utils.set_mcp_usage_header(headers)

    async with ws_connect(
        uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx
    ) as ws:
      await ws.send(request)
      # The first message from the server is the response to the setup request.
      response = await _recv_json(ws)

      if self._api_client.vertexai:
        response_dict = live_converters._LiveServerMessage_from_vertex(response)
      else:
        response_dict = response

      setup_response = types.LiveServerMessage._from_response(
          response=response_dict, kwargs=parameter_model.model_dump()
      )
      if setup_response.setup_complete:
        session_id = setup_response.setup_complete.session_id
      else:
        session_id = None
      yield AsyncSession(
          api_client=self._api_client,
          websocket=ws,
          session_id=session_id,
      )


async def _t_live_connect_config(
    api_client: BaseApiClient,
    config: Optional[types.LiveConnectConfigOrDict],
) -> types.LiveConnectConfig:
  """Normalizes the user config into a `LiveConnectConfig` ready to be sent.

  The system instruction is converted with `t.t_content`, and MCP sessions and
  MCP tools found in `tools` are converted to GenAI tools.

  Note: when `config` is a `LiveConnectConfig` instance, its
  `system_instruction` attribute is overwritten in place with the converted
  value; the returned object is a copy.

  Args:
    api_client: The API client (not used by the conversion itself).
    config: The user supplied config, as a model, a dict, or None.

  Returns:
    A `LiveConnectConfig` copy with converted system instruction and tools.
  """
  # Ensure the config is a LiveConnectConfig.
  if config is None:
    parameter_model = types.LiveConnectConfig()
  elif isinstance(config, dict):
    if getv(config, ['system_instruction']) is not None:
      converted_system_instruction = t.t_content(
          getv(config, ['system_instruction'])
      )
    else:
      converted_system_instruction = None
    parameter_model = types.LiveConnectConfig(**config)
    parameter_model.system_instruction = converted_system_instruction
  else:
    if config.system_instruction is None:
      system_instruction = None
    else:
      system_instruction = t.t_content(getv(config, ['system_instruction']))
    parameter_model = config
    parameter_model.system_instruction = system_instruction

  # Create a copy of the config model with the tools field cleared as they will
  # be replaced with the MCP tools converted to GenAI tools.
  parameter_model_copy = parameter_model.model_copy(update={'tools': None})
  if parameter_model.tools:
    parameter_model_copy.tools = []
    for tool in parameter_model.tools:
      if McpClientSession is not None and isinstance(tool, McpClientSession):
        mcp_to_genai_tool_adapter = McpToGenAiToolAdapter(
            tool, await tool.list_tools()
        )
        # Extend the config with the MCP session tools converted to GenAI tools.
        parameter_model_copy.tools.extend(mcp_to_genai_tool_adapter.tools)
      elif McpTool is not None and isinstance(tool, McpTool):
        parameter_model_copy.tools.append(mcp_to_gemini_tool(tool))
      else:
        parameter_model_copy.tools.append(tool)

  if parameter_model_copy.generation_config is not None:
    warnings.warn(
        'Setting `LiveConnectConfig.generation_config` is deprecated, '
        'please set the fields on `LiveConnectConfig` directly. This will '
        'become an error in a future version (not before Q3 2025)',
        DeprecationWarning,
        stacklevel=4,
    )

  return parameter_model_copy