# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""[Experimental] Live Music API client."""

import contextlib
import json
import logging
from typing import AsyncIterator

from . import _api_module
from . import _common
from . import _live_converters as live_converters
from . import _transformers as t
from . import types
from ._api_client import BaseApiClient
from ._common import set_value_by_path as setv

try:
  from websockets.asyncio.client import ClientConnection
  from websockets.asyncio.client import connect
except ModuleNotFoundError:
  from websockets.client import ClientConnection  # type: ignore
  from websockets.client import connect  # type: ignore

logger = logging.getLogger('google_genai.live_music')

# Error raised by every entry point when the client cannot use live music.
_VERTEX_NOT_SUPPORTED_MESSAGE = (
    'Live music generation is not supported in Vertex AI.'
)


async def _recv_raw(ws: ClientConnection):
  """Receives one websocket message, undecoded when the API allows it.

  The websockets asyncio API accepts `recv(decode=False)`. When `recv` rejects
  that keyword with a TypeError (presumably the older websockets API), fall
  back to a plain `recv()`.

  Args:
    ws: The open websocket connection.

  Returns:
    The raw message payload as returned by `recv`.
  """
  try:
    return await ws.recv(decode=False)
  except TypeError:
    return await ws.recv()  # type: ignore[call-arg]


class AsyncMusicSession:
  """[Experimental] AsyncMusicSession.

  Wraps an open websocket to the live music service. Exposes methods to send
  prompts, generation config and playback signals, and to receive server
  messages. Instances are yielded by `AsyncLiveMusic.connect`.
  """

  def __init__(self, api_client: BaseApiClient, websocket: ClientConnection):
    self._api_client = api_client
    self._ws = websocket

  def _ensure_mldev(self) -> None:
    """Raises NotImplementedError when the client is configured for Vertex AI."""
    if self._api_client.vertexai:
      raise NotImplementedError(_VERTEX_NOT_SUPPORTED_MESSAGE)

  async def set_weighted_prompts(
      self, prompts: list[types.WeightedPrompt]
  ) -> None:
    """Sends a set of weighted prompts to steer the generated music.

    Args:
      prompts: The prompts to send. Each prompt is serialized with
        `_common.convert_to_dict(..., convert_keys=True)` and placed under
        `clientContent.weightedPrompts`.

    Raises:
      NotImplementedError: If the client is configured for Vertex AI.
    """
    self._ensure_mldev()
    client_content_dict = {
        'weightedPrompts': [
            _common.convert_to_dict(prompt, convert_keys=True)
            for prompt in prompts
        ]
    }
    await self._ws.send(json.dumps({'clientContent': client_content_dict}))

  async def set_music_generation_config(
      self, config: types.LiveMusicGenerationConfig
  ) -> None:
    """Sends a new music generation config to the server.

    Args:
      config: The config to send. It is serialized under the
        `musicGenerationConfig` key.

    Raises:
      NotImplementedError: If the client is configured for Vertex AI.
    """
    self._ensure_mldev()
    config_dict = _common.convert_to_dict(config, convert_keys=True)
    await self._ws.send(json.dumps({'musicGenerationConfig': config_dict}))

  async def _send_control_signal(
      self, playback_control: types.LiveMusicPlaybackControl
  ) -> None:
    """Sends a playback control message carrying `playback_control.value`.

    Raises:
      NotImplementedError: If the client is configured for Vertex AI.
    """
    self._ensure_mldev()
    playback_control_dict = {'playbackControl': playback_control.value}
    await self._ws.send(json.dumps(playback_control_dict))

  async def play(self) -> None:
    """Sends playback signal to start the music stream."""
    return await self._send_control_signal(types.LiveMusicPlaybackControl.PLAY)

  async def pause(self) -> None:
    """Sends a playback signal to pause the music stream."""
    return await self._send_control_signal(types.LiveMusicPlaybackControl.PAUSE)

  async def stop(self) -> None:
    """Sends a playback signal to stop the music stream.

    Resets the music generation context while retaining the current config.
    """
    return await self._send_control_signal(types.LiveMusicPlaybackControl.STOP)

  async def reset_context(self) -> None:
    """Reset the context (prompts retained) without stopping the music generation."""
    return await self._send_control_signal(
        types.LiveMusicPlaybackControl.RESET_CONTEXT
    )

  async def receive(self) -> AsyncIterator[types.LiveMusicServerMessage]:
    """Receive model responses from the server.

    Yields:
      Each parsed `LiveMusicServerMessage` (these carry the audio chunks).
      Iteration stops when `_receive` returns a falsy value or raises.
    """
    # TODO(b/365983264) Handle intermittent issues for the user.
    # NOTE(review): message objects are presumably always truthy, so in
    # practice the loop ends only when `_receive` raises (e.g. when the
    # websocket is closed) - confirm.
    while result := await self._receive():
      yield result

  async def _receive(self) -> types.LiveMusicServerMessage:
    """Reads one websocket frame and parses it into a server message.

    Returns:
      The parsed message. An empty frame is parsed as `{}`.

    Raises:
      ValueError: If a non-empty frame is not valid JSON.
      NotImplementedError: If the client is configured for Vertex AI
        (checked after the frame has been read).
    """
    parameter_model = types.LiveMusicServerMessage()
    raw_response = await _recv_raw(self._ws)
    if raw_response:
      try:
        response = json.loads(raw_response)
      except json.decoder.JSONDecodeError as e:
        raise ValueError(f'Failed to parse response: {raw_response!r}') from e
    else:
      response = {}

    self._ensure_mldev()
    return types.LiveMusicServerMessage._from_response(
        response=response, kwargs=parameter_model.model_dump()
    )

  async def close(self) -> None:
    """Closes the bi-directional stream and terminates the session."""
    await self._ws.close()


class AsyncLiveMusic(_api_module.BaseModule):
  """[Experimental] Live music module.

  Live music can be accessed via `client.aio.live.music`.
  """

  @_common.experimental_warning(
      'Realtime music generation is experimental and may change in future versions.'
  )
  @contextlib.asynccontextmanager
  async def connect(self, *, model: str) -> AsyncIterator[AsyncMusicSession]:
    """[Experimental] Connect to the live music server.

    Opens the websocket, sends the setup request and logs the server's reply.
    It then yields a session bound to that websocket. The websocket is closed
    when the caller's `async with` block exits.

    Args:
      model: The model name. It is normalized with `t.t_model` before use.

    Yields:
      An `AsyncMusicSession` for the open connection.

    Raises:
      NotImplementedError: If the client has no API key. Only the Gemini
        Developer API (API-key) path is supported, not Vertex AI.
    """
    base_url = self._api_client._websocket_base_url()
    if isinstance(base_url, bytes):
      base_url = base_url.decode('utf-8')
    transformed_model = t.t_model(self._api_client, model)

    if not self._api_client.api_key:
      raise NotImplementedError(_VERTEX_NOT_SUPPORTED_MESSAGE)

    api_key = self._api_client.api_key
    version = self._api_client._http_options.api_version
    uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic?key={api_key}'
    headers = self._api_client._http_options.headers

    # Only mldev supported
    request_dict = _common.convert_to_dict(
        live_converters._LiveMusicConnectParameters_to_mldev(
            from_object=types.LiveMusicConnectParameters(
                model=transformed_model,
            ).model_dump(exclude_none=True)
        )
    )
    setv(request_dict, ['setup', 'model'], transformed_model)
    request = json.dumps(request_dict)

    # The API-version fallback is confined to opening the connection. A
    # TypeError raised by the caller's code inside the `async with` body must
    # propagate; it must not trigger a reconnect and a second `yield`.
    async with contextlib.AsyncExitStack() as stack:
      try:
        ws = await stack.enter_async_context(
            connect(uri, additional_headers=headers)
        )
      except TypeError:
        # Try with the older websockets API, which takes `extra_headers`.
        ws = await stack.enter_async_context(
            connect(uri, extra_headers=headers)
        )
      await ws.send(request)
      logger.info(await _recv_raw(ws))
      yield AsyncMusicSession(api_client=self._api_client, websocket=ws)