# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Common utilities for the SDK."""

import base64
import collections.abc
import datetime
import enum
import functools
import logging
import re
import typing
from typing import Any, Callable, FrozenSet, Optional, Union, get_args, get_origin
import uuid
import warnings

import pydantic
from pydantic import alias_generators
from typing_extensions import TypeAlias

logger = logging.getLogger('google_genai._common')

StringDict: TypeAlias = dict[str, Any]


class ExperimentalWarning(Warning):
  """Warning for experimental features."""


def set_value_by_path(
    data: Optional[dict[Any, Any]], keys: list[str], value: Any
) -> None:
  """Sets `value` inside the nested dict `data` at the path given by `keys`.

  Path segments ending in '[]' fan the assignment out over every element of
  a list; segments ending in '[0]' address the first element of a list. Any
  missing intermediate dicts/lists are created. Nothing happens when `value`
  is None or `data` is None.

  Examples:

  set_value_by_path({}, ['a', 'b'], v)
    -> {'a': {'b': v}}
  set_value_by_path({}, ['a', 'b[]', 'c'], [v1, v2])
    -> {'a': {'b': [{'c': v1}, {'c': v2}]}}
  set_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'd'], v3)
    -> {'a': {'b': [{'c': v1, 'd': v3}, {'c': v2, 'd': v3}]}}

  Args:
    data: The dictionary to mutate in place.
    keys: The path segments; the last one is the key that receives `value`.
    value: The value to store. A list value is distributed element-wise
      across a '[]' segment.

  Raises:
    ValueError: If a '[]' segment has to be created but `value` is not a
      list, or if the final key already holds a different, non-mergeable
      value.
  """
  if value is None:
    return
  for i, key in enumerate(keys[:-1]):
    if key.endswith('[]'):
      key_name = key[:-2]
      if data is not None and key_name not in data:
        if isinstance(value, list):
          # One fresh dict per element of the incoming list.
          data[key_name] = [{} for _ in value]
        else:
          raise ValueError(
              f'value {value} must be a list given an array path {key}'
          )
      if isinstance(value, list) and data is not None:
        # Pair each existing element with the value at the same index.
        for j, d in enumerate(data[key_name]):
          set_value_by_path(d, keys[i + 1 :], value[j])
      else:
        if data is not None:
          # Scalar value: broadcast it to every element.
          for d in data[key_name]:
            set_value_by_path(d, keys[i + 1 :], value)
      # The recursive calls handled the remainder of the path.
      return
    elif key.endswith('[0]'):
      key_name = key[:-3]
      if data is not None and key_name not in data:
        data[key_name] = [{}]
      if data is not None:
        set_value_by_path(data[key_name][0], keys[i + 1 :], value)
      return
    if data is not None:
      data = data.setdefault(key, {})

  if data is not None:
    existing_data = data.get(keys[-1])
    # If there is an existing value, merge, not overwrite.
    if existing_data is not None:
      # Don't overwrite existing non-empty value with new empty value.
      # This is triggered when handling tuning datasets.
      if not value:
        pass
      # Don't fail when overwriting value with same value
      elif value == existing_data:
        pass
      # Instead of overwriting dictionary with another dictionary, merge them.
      # This is important for handling training and validation datasets in
      # tuning.
      elif isinstance(existing_data, dict) and isinstance(value, dict):
        # Merging dictionaries. Consider deep merging in the future.
        existing_data.update(value)
      else:
        raise ValueError(
            f'Cannot set value for an existing key. Key: {keys[-1]};'
            f' Existing value: {existing_data}; New value: {value}.'
        )
    else:
      if (
          keys[-1] == '_self'
          and isinstance(data, dict)
          and isinstance(value, dict)
      ):
        # '_self' merges the dict value into the current level instead of
        # nesting it under a '_self' key.
        data.update(value)
      else:
        data[keys[-1]] = value


def get_value_by_path(
    data: Any, keys: list[str], *, default_value: Any = None
) -> Any:
  """Reads the value stored in `data` at the path given by `keys`.

  Examples:

  get_value_by_path({'a': {'b': v}}, ['a', 'b'])
    -> v
  get_value_by_path({'a': {'b': [{'c': v1}, {'c': v2}]}}, ['a', 'b[]', 'c'])
    -> [v1, v2]

  Args:
    data: The container to read from. Plain keys are looked up with `in` /
      indexing, with an attribute fallback for `BaseModel` instances.
    keys: The path segments. '[]' collects the remaining path from every
      list element; '[0]' follows the first list element. The single-element
      path ['_self'] returns `data` itself.
    default_value: Returned when the path cannot be resolved or a falsy
      container is met while segments remain.

  Returns:
    The value found at the path, or `default_value`.
  """
  if keys == ['_self']:
    return data
  for i, key in enumerate(keys):
    if not data:
      return default_value
    if key.endswith('[]'):
      key_name = key[:-2]
      if key_name in data:
        return [
            get_value_by_path(d, keys[i + 1 :], default_value=default_value)
            for d in data[key_name]
        ]
      else:
        return default_value
    elif key.endswith('[0]'):
      key_name = key[:-3]
      if key_name in data and data[key_name]:
        return get_value_by_path(
            data[key_name][0], keys[i + 1 :], default_value=default_value
        )
      else:
        return default_value
    else:
      if key in data:
        data = data[key]
      elif isinstance(data, BaseModel) and hasattr(data, key):
        data = getattr(data, key)
      else:
        return default_value
  return data


def move_value_by_path(data: Any, paths: dict[str, str]) -> None:
  """Moves values from source paths to destination paths.

  Examples:

  move_value_by_path(
      {'requests': [{'content': v1}, {'content': v2}]},
      {'requests[].*': 'requests[].request.*'}
  )
    -> {'requests': [{'request': {'content': v1}},
                     {'request': {'content': v2}}]}

  Args:
    data: The structure to mutate in place.
    paths: Maps a dot-separated source path to a dot-separated destination
      path. '*' in the source selects every movable key at that level and is
      replaced by the key name in the destination.
  """
  for source_path, dest_path in paths.items():
    source_keys = source_path.split('.')
    dest_keys = dest_path.split('.')

    # Determine keys to exclude from wildcard to avoid cyclic references
    exclude_keys = set()
    wildcard_idx = -1
    for i, key in enumerate(source_keys):
      if key == '*':
        wildcard_idx = i
        break

    if wildcard_idx != -1 and len(dest_keys) > wildcard_idx:
      # Extract the intermediate key between source and dest paths
      # Example: source=['requests[]', '*'], dest=['requests[]', 'request', '*']
      # We want to exclude 'request'
      for key in dest_keys[wildcard_idx:]:
        if key != '*' and not key.endswith('[]') and not key.endswith('[0]'):
          exclude_keys.add(key)

    # Move values recursively
    _move_value_recursive(data, source_keys, dest_keys, 0, exclude_keys)


def _move_value_recursive(
    data: Any,
    source_keys: list[str],
    dest_keys: list[str],
    key_idx: int,
    exclude_keys: set[str],
) -> None:
  """Recursively moves values from source path to destination path.

  Args:
    data: The container reached so far while walking `source_keys`.
    source_keys: The split source path.
    dest_keys: The split destination path.
    key_idx: Index of the source segment to process at this level.
    exclude_keys: Keys the '*' wildcard must leave in place.
  """
  if key_idx >= len(source_keys):
    return

  key = source_keys[key_idx]

  if key.endswith('[]'):
    # Handle array iteration
    key_name = key[:-2]
    if key_name in data and isinstance(data[key_name], list):
      for item in data[key_name]:
        _move_value_recursive(
            item, source_keys, dest_keys, key_idx + 1, exclude_keys
        )
  elif key == '*':
    # Handle wildcard - move all fields
    if isinstance(data, dict):
      # Get all keys to move (excluding specified keys). Keys starting with
      # '_' are also left where they are.
      keys_to_move = [
          k for k in data if not k.startswith('_') and k not in exclude_keys
      ]

      # Snapshot the values before the dict is mutated below.
      values_to_move = {k: data[k] for k in keys_to_move}

      # Set values at destination
      for k, v in values_to_move.items():
        # Build destination keys with the field name substituted for '*'.
        new_dest_keys = [
            k if dk == '*' else dk for dk in dest_keys[key_idx:]
        ]
        set_value_by_path(data, new_dest_keys, v)

      # Delete from source
      for k in keys_to_move:
        del data[k]
  else:
    # Navigate to next level
    if key in data:
      _move_value_recursive(
          data[key], source_keys, dest_keys, key_idx + 1, exclude_keys
      )


def maybe_snake_to_camel(snake_str: str, convert: bool = True) -> str:
  """Converts a snake_case string to camelCase, if convert is True.

  Every underscore followed by an ASCII letter is replaced by that letter in
  upper case; the first character of the string is left untouched.

  Args:
    snake_str: The string to convert.
    convert: When False, `snake_str` is returned unchanged.

  Returns:
    The converted (or original) string.
  """
  if not convert:
    return snake_str
  return re.sub(r'_([a-zA-Z])', lambda match: match.group(1).upper(), snake_str)


def convert_to_dict(obj: object, convert_keys: bool = False) -> Any:
  """Recursively converts a given object to a dictionary.

  If the object is a Pydantic model, it uses the model's `model_dump()` method.

  Args:
    obj: The object to convert.
    convert_keys: Whether to convert the keys from snake case to camel case.

  Returns:
    A dictionary representation of the object, a list of objects if a list is
    passed, or the object itself if it is not a dictionary, list, or Pydantic
    model.
  """
  if isinstance(obj, pydantic.BaseModel):
    return convert_to_dict(obj.model_dump(exclude_none=True), convert_keys)
  elif isinstance(obj, dict):
    # NOTE(review): `convert_keys` is not forwarded to nested values here
    # (unlike the list branch below), so only this level's keys are
    # converted. Left as-is because callers may rely on it - confirm intent.
    return {
        maybe_snake_to_camel(key, convert_keys): convert_to_dict(value)
        for key, value in obj.items()
    }
  elif isinstance(obj, list):
    return [convert_to_dict(item, convert_keys) for item in obj]
  else:
    return obj


def _is_struct_type(annotation: type) -> bool:
  """Checks if the given annotation is list[dict[str, typing.Any]]

  or typing.List[typing.Dict[str, typing.Any]].

  This maps to Struct type in the API.

  Args:
    annotation: The type annotation to inspect.

  Returns:
    True only for a list of exactly one type argument that is a dict whose
    two type arguments are `str` and `typing.Any`.
  """
  outer_origin = get_origin(annotation)
  outer_args = get_args(annotation)

  if outer_origin is not list:  # Python 3.9+ normalizes list
    return False

  if not outer_args or len(outer_args) != 1:
    return False

  inner_annotation = outer_args[0]

  inner_origin = get_origin(inner_annotation)
  inner_args = get_args(inner_annotation)

  if inner_origin is not dict:  # Python 3.9+ normalizes to dict
    return False

  if not inner_args or len(inner_args) != 2:
    # dict should have exactly two type arguments
    return False

  # Check if the dict arguments are str and typing.Any
  key_type, value_type = inner_args
  return key_type is str and value_type is typing.Any


def _remove_extra_fields(model: Any, response: dict[str, object]) -> None:
  """Removes extra fields from the response that are not in the model.

  Mutates the response in place.

  Args:
    model: The model class whose `model_fields` define the allowed keys
      (matched by field name or by alias).
    response: The response dictionary to prune; nested dicts and lists of
      dicts are pruned recursively against the field annotations.
  """
  if not response:
    # Nothing to prune. Returning early also avoids touching
    # `model.model_fields` when `model` is not actually a model class.
    return

  model_fields = model.model_fields
  # Need to convert to snake case to match model fields names
  # ex: UsageMetadata
  # Built once: it only depends on the model, not on the response entry.
  alias_map = {
      field_info.alias: name for name, field_info in model_fields.items()
  }

  # Iterate over a snapshot because entries are popped during the loop.
  for key, value in list(response.items()):
    if key not in model_fields and key not in alias_map:
      response.pop(key)
      continue

    field_name = alias_map.get(key, key)
    annotation = model_fields[field_name].annotation

    # Get the BaseModel if Optional
    if typing.get_origin(annotation) is Union:
      annotation = typing.get_args(annotation)[0]

    # if dict, assume BaseModel but also check that field type is not dict
    # example: FunctionCall.args
    if isinstance(value, dict) and typing.get_origin(annotation) is not dict:
      _remove_extra_fields(annotation, value)
    elif isinstance(value, list):
      if _is_struct_type(annotation):
        continue

      for item in value:
        # assume a list of dict is list of BaseModel
        if isinstance(item, dict):
          _remove_extra_fields(typing.get_args(annotation)[0], item)


T = typing.TypeVar('T', bound='BaseModel')


def _pretty_repr(
    obj: Any,
    *,
    indent_level: int = 0,
    indent_delta: int = 2,
    max_len: int = 100,
    max_items: int = 5,
    depth: int = 6,
    visited: Optional[FrozenSet[int]] = None,
) -> str:
  """Returns a representation of the given object.

  Args:
    obj: The object to render.
    indent_level: Number of spaces the current nesting level is indented by.
    indent_delta: Number of spaces added per nesting level.
    max_len: Maximum length of a `bytes` value before it is truncated.
    max_items: Maximum number of collection elements shown; mappings with
      more entries are collapsed to a placeholder.
    depth: Remaining nesting levels before a "Max depth" placeholder is used.
    visited: The `id()`s of the objects on the current path, used to detect
      circular references.

  Returns:
    The multi-line, indented string representation.
  """
  if visited is None:
    visited = frozenset()

  obj_id = id(obj)
  if obj_id in visited:
    return '<... Circular reference ...>'

  if depth < 0:
    return '<... Max depth ...>'

  # A new frozenset per level, so siblings do not see each other's ids.
  visited = visited | {obj_id}

  indent = ' ' * indent_level
  next_indent_str = ' ' * (indent_level + indent_delta)

  if isinstance(obj, pydantic.BaseModel):
    cls_name = obj.__class__.__name__
    items = []
    # Sort fields for consistent output
    fields = sorted(type(obj).model_fields)

    for field_name in fields:
      field_info = type(obj).model_fields[field_name]
      if not field_info.repr:  # Respect Field(repr=False)
        continue

      try:
        value = getattr(obj, field_name)
      except AttributeError:
        continue

      if value is None:
        continue

      value_repr = _pretty_repr(
          value,
          indent_level=indent_level + indent_delta,
          indent_delta=indent_delta,
          max_len=max_len,
          max_items=max_items,
          depth=depth - 1,
          visited=visited,
      )
      items.append(f'{next_indent_str}{field_name}={value_repr}')

    if not items:
      return f'{cls_name}()'
    return f'{cls_name}(\n' + ',\n'.join(items) + f'\n{indent})'
  elif isinstance(obj, str):
    if '\n' in obj:
      # Multi-line strings are shown as a triple-quoted block; embedded
      # triple quotes are escaped so the block stays well-formed.
      escaped = obj.replace('"""', '\\"\\"\\"')
      return f'"""{escaped}"""'
    return repr(obj)
  elif isinstance(obj, bytes):
    if len(obj) > max_len:
      # Drop the closing quote of the truncated repr and append "...'".
      return f"{repr(obj[:max_len-3])[:-1]}...'"
    return repr(obj)
  elif isinstance(obj, collections.abc.Mapping):
    if not obj:
      return '{}'

    # Check if the next level of recursion for keys/values will exceed the
    # depth limit.
    if depth <= 0:
      item_count_str = f"{len(obj)} item{'s' if len(obj) != 1 else ''}"
      return f'{{<... {item_count_str} at Max depth ...>}}'

    if len(obj) > max_items:
      # Large mappings are collapsed to a short placeholder instead of
      # being expanded (previously this returned an empty string).
      return f'<dict len={len(obj)}>'

    items = []
    try:
      sorted_keys = sorted(obj.keys(), key=str)
    except TypeError:
      sorted_keys = list(obj.keys())

    for k in sorted_keys:
      v = obj[k]
      k_repr = _pretty_repr(
          k,
          indent_level=indent_level + indent_delta,
          indent_delta=indent_delta,
          max_len=max_len,
          max_items=max_items,
          depth=depth - 1,
          visited=visited,
      )
      v_repr = _pretty_repr(
          v,
          indent_level=indent_level + indent_delta,
          indent_delta=indent_delta,
          max_len=max_len,
          max_items=max_items,
          depth=depth - 1,
          visited=visited,
      )
      items.append(f'{next_indent_str}{k_repr}: {v_repr}')
    return f'{{\n' + ',\n'.join(items) + f'\n{indent}}}'
  elif isinstance(obj, (list, tuple, set)):
    return _format_collection(
        obj,
        indent_level=indent_level,
        indent_delta=indent_delta,
        max_len=max_len,
        max_items=max_items,
        depth=depth,
        visited=visited,
    )
  else:
    # Fallback to standard repr, indenting subsequent lines only
    raw_repr = repr(obj)
    # Replace newlines with newline + indent
    return raw_repr.replace('\n', f'\n{next_indent_str}')


def _format_collection(
    obj: Any,
    *,
    indent_level: int,
    indent_delta: int,
    max_len: int,
    max_items: int,
    depth: int,
    visited: FrozenSet[int],
) -> str:
  """Formats a collection (list, tuple, set).

  Args:
    obj: The list, tuple or set to render.
    indent_level: Number of spaces the collection itself is indented by.
    indent_delta: Number of spaces added for the elements.
    max_len: Forwarded to `_pretty_repr` for the elements.
    max_items: Maximum number of elements rendered; the rest are summarised.
    depth: Remaining nesting levels.
    visited: The `id()`s of the objects on the current path.

  Returns:
    The multi-line string representation of the collection.

  Raises:
    ValueError: If `obj` is not a list, tuple or set.
  """
  if isinstance(obj, list):
    brackets = ('[', ']')
    internal_obj = obj
  elif isinstance(obj, tuple):
    brackets = ('(', ')')
    internal_obj = list(obj)
  elif isinstance(obj, set):
    internal_obj = list(obj)
    if obj:
      brackets = ('{', '}')
    else:
      # '{}' would read as an empty dict, so spell the empty set out.
      brackets = ('set(', ')')
  else:
    raise ValueError(f'Unsupported collection type: {type(obj)}')

  if not internal_obj:
    return brackets[0] + brackets[1]

  # If the call to _pretty_repr for elements will have depth < 0
  if depth <= 0:
    item_count_str = f"{len(internal_obj)} item{'s'*(len(internal_obj)!=1)}"
    return f'{brackets[0]}<... {item_count_str} at Max depth ...>{brackets[1]}'

  indent = ' ' * indent_level
  next_indent_str = ' ' * (indent_level + indent_delta)
  elements = []
  for elem in internal_obj[:max_items]:
    elements.append(
        next_indent_str
        + _pretty_repr(
            elem,
            indent_level=indent_level + indent_delta,
            indent_delta=indent_delta,
            max_len=max_len,
            max_items=max_items,
            depth=depth - 1,
            visited=visited,
        )
    )

  if len(internal_obj) > max_items:
    elements.append(
        f'{next_indent_str}<... {len(internal_obj) - max_items} more items ...>'
    )

  return f'{brackets[0]}\n' + ',\n'.join(elements) + f',\n{indent}{brackets[1]}'


class BaseModel(pydantic.BaseModel):
  """Base class for the SDK's Pydantic models.

  Configures camelCase aliases, forbids unknown fields, uses base64 for
  bytes in JSON mode, and provides a pretty multi-line `__repr__`.
  """

  model_config = pydantic.ConfigDict(
      alias_generator=alias_generators.to_camel,
      populate_by_name=True,
      from_attributes=True,
      protected_namespaces=(),
      extra='forbid',
      # This allows us to use arbitrary types in the model. E.g. PIL.Image.
      arbitrary_types_allowed=True,
      ser_json_bytes='base64',
      val_json_bytes='base64',
      ignored_types=(typing.TypeVar,),
  )

  @pydantic.model_validator(mode='before')
  @classmethod
  def _check_field_type_mismatches(cls, data: Any) -> Any:
    """Check for type mismatches and warn before Pydantic processes the data.

    Only logs a warning when a field annotated with a Pydantic model class
    (optionally wrapped in Optional) receives an instance of a different
    Pydantic model class. The data itself is never modified.

    Args:
      data: The raw input handed to the model.

    Returns:
      `data`, unchanged.
    """
    # Handle both dict and Pydantic model inputs
    if not isinstance(data, (dict, pydantic.BaseModel)):
      return data

    for field_name, field_info in cls.model_fields.items():
      if isinstance(data, dict):
        value = data.get(field_name)
      else:
        value = getattr(data, field_name, None)

      if value is None:
        continue

      expected_type = field_info.annotation
      origin = get_origin(expected_type)

      if origin is Union:
        # Unwrap Optional[X]; other unions are left as they are.
        args = get_args(expected_type)
        non_none_types = [arg for arg in args if arg is not type(None)]
        if len(non_none_types) == 1:
          expected_type = non_none_types[0]

      if (
          isinstance(expected_type, type)
          and get_origin(expected_type) is None
          and issubclass(expected_type, pydantic.BaseModel)
          and isinstance(value, pydantic.BaseModel)
          and not isinstance(value, expected_type)
      ):
        logger.warning(
            f"Type mismatch in {cls.__name__}.{field_name}: "
            f"expected {expected_type.__name__}, got {type(value).__name__}"
        )

    return data

  def __repr__(self) -> str:
    """Returns the pretty representation, or Pydantic's own as a fallback."""
    try:
      return _pretty_repr(self)
    except Exception:
      # Deliberate best effort: repr() must never raise.
      return super().__repr__()

  @classmethod
  def _from_response(
      cls: typing.Type[T],
      *,
      response: dict[str, object],
      kwargs: dict[str, object],
  ) -> T:
    """Builds a model instance from an API response dictionary.

    Args:
      response: The response dictionary; unknown fields are removed from it
        in place unless `kwargs['config']['include_all_fields']` is truthy.
      kwargs: The request keyword arguments; only `config` is inspected.

    Returns:
      The validated model instance.
    """
    # To maintain forward compatibility, we need to remove extra fields from
    # the response.
    # We will provide another mechanism to allow users to access these fields.

    # For Agent Engine we don't want to call _remove_extra_fields because the
    # user may pass a dict that is not a subclass of BaseModel.
    # If more modules require we skip this, we may want a different approach
    should_skip_removing_fields = (
        kwargs is not None
        and 'config' in kwargs
        and kwargs['config'] is not None
        and isinstance(kwargs['config'], dict)
        and 'include_all_fields' in kwargs['config']
        and kwargs['config']['include_all_fields']
    )

    if not should_skip_removing_fields:
      _remove_extra_fields(cls, response)
    validated_response = cls.model_validate(response)
    return validated_response

  def to_json_dict(self) -> dict[str, object]:
    """Dumps the model in JSON mode, omitting fields that are None."""
    return self.model_dump(exclude_none=True, mode='json')


class CaseInSensitiveEnum(str, enum.Enum):
  """Case insensitive enum."""

  @classmethod
  def _missing_(cls, value: Any) -> Any:
    """Resolves a value that did not match any member directly.

    Tries the upper-cased and then the lower-cased value as a member name.
    If neither exists, emits a warning and returns an ad-hoc member carrying
    the unknown value.

    Args:
      value: The value that failed the regular lookup.

    Returns:
      The matching member, an ad-hoc member for an unknown value, or None
      if the ad-hoc member could not be created.
    """
    try:
      return cls[value.upper()]  # Try to access directly with uppercase
    except KeyError:
      try:
        return cls[value.lower()]  # Try to access directly with lowercase
      except KeyError:
        warnings.warn(f'{value} is not a valid {cls.__name__}')
        try:
          # Creating a enum instance based on the value
          # We need to use super() to avoid infinite recursion.
          unknown_enum_val = super().__new__(cls, value)
          unknown_enum_val._name_ = str(value)  # pylint: disable=protected-access
          unknown_enum_val._value_ = value  # pylint: disable=protected-access
          return unknown_enum_val
        except Exception:
          # Best effort only; `Exception` (not a bare except) so that
          # KeyboardInterrupt / SystemExit still propagate.
          return None


def timestamped_unique_name() -> str:
  """Composes a timestamped unique name.

  The name is the current local time formatted as YYYYmmddHHMMSS, an
  underscore, and the first five hex digits of a random UUID4.

  Returns:
    A string representing a unique name.
  """
  timestamp = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
  unique_id = uuid.uuid4().hex[0:5]
  return f'{timestamp}_{unique_id}'


def encode_unserializable_types(data: dict[str, object]) -> dict[str, object]:
  """Converts unserializable types in dict to json.dumps() compatible types.

  This function is called in models.py after calling convert_to_dict(). The
  convert_to_dict() can convert pydantic object to dict. However, the input to
  convert_to_dict() is dict mixed of pydantic object and nested dict(the output
  of converters). So they may be bytes in the dict and they are out of
  `ser_json_bytes` control in model_dump(mode='json') called in
  `convert_to_dict`, as well as datetime deserialization in Pydantic json mode.

  Args:
    data: The dictionary to process. A non-dict argument is returned as-is,
      which is what makes the recursion over list elements safe.

  Returns:
    A dictionary with json.dumps() incompatible type (e.g. bytes datetime) to
    compatible type (e.g. base64 encoded string, isoformat date string).
  """
  if not isinstance(data, dict):
    return data
  processed_data: dict[str, object] = {}
  for key, value in data.items():
    if isinstance(value, bytes):
      processed_data[key] = base64.urlsafe_b64encode(value).decode('ascii')
    elif isinstance(value, datetime.datetime):
      processed_data[key] = value.isoformat()
    elif isinstance(value, dict):
      processed_data[key] = encode_unserializable_types(value)
    elif isinstance(value, list):
      if all(isinstance(v, bytes) for v in value):
        processed_data[key] = [
            base64.urlsafe_b64encode(v).decode('ascii') for v in value
        ]
      # `elif`, not `if`: otherwise the `else` below overwrites the encoded
      # bytes list with the raw, unserializable bytes.
      elif all(isinstance(v, datetime.datetime) for v in value):
        processed_data[key] = [v.isoformat() for v in value]
      else:
        processed_data[key] = [encode_unserializable_types(v) for v in value]
    else:
      processed_data[key] = value
  return processed_data


def experimental_warning(
    message: str,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
  """Experimental warning, only warns once.

  Args:
    message: The text of the `ExperimentalWarning` to emit.

  Returns:
    A decorator. The wrapped function emits the warning on its first call
    only; every decorated function keeps its own "already warned" flag.
  """

  def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
    warning_done = False

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
      nonlocal warning_done
      if not warning_done:
        warning_done = True
        warnings.warn(
            message=message,
            category=ExperimentalWarning,
            # Attribute the warning to the caller of the decorated function.
            stacklevel=2,
        )
      return func(*args, **kwargs)

    return wrapper

  return decorator


def _normalize_key_for_matching(key_str: str) -> str:
  """Normalizes a key for case-insensitive and snake/camel matching."""
  return key_str.replace('_', '').lower()


def align_key_case(
    target_dict: StringDict, update_dict: StringDict
) -> StringDict:
  """Aligns the keys of update_dict to the case of target_dict keys.

  Args:
    target_dict: The dictionary with the target key casing.
    update_dict: The dictionary whose keys need to be aligned.

  Returns:
    A new dictionary with keys aligned to target_dict's key casing.
  """
  aligned_update_dict: StringDict = {}
  target_keys_map = {
      _normalize_key_for_matching(key): key for key in target_dict.keys()
  }

  for key, value in update_dict.items():
    normalized_update_key = _normalize_key_for_matching(key)

    if normalized_update_key in target_keys_map:
      aligned_key = target_keys_map[normalized_update_key]
    else:
      # No counterpart in the target: keep the key as provided.
      aligned_key = key

    if isinstance(value, dict) and isinstance(
        target_dict.get(aligned_key), dict
    ):
      aligned_update_dict[aligned_key] = align_key_case(
          target_dict[aligned_key], value
      )
    elif isinstance(value, list) and isinstance(
        target_dict.get(aligned_key), list
    ):
      # Direct assign as we treat update_dict list values as golden source.
      aligned_update_dict[aligned_key] = value
    else:
      aligned_update_dict[aligned_key] = value
  return aligned_update_dict


def recursive_dict_update(
    target_dict: StringDict, update_dict: StringDict
) -> None:
  """Recursively updates a target dictionary with values from an update dictionary.

  We don't enforce the updated dict values to have the same type with the
  target_dict values except log warnings.
  Users providing the update_dict should be responsible for constructing correct
  data.

  Args:
    target_dict (dict): The dictionary to be updated.
    update_dict (dict): The dictionary containing updates.
  """
  # Python SDK http request may change in camel case or snake case:
  # If the field is directly set via setv() function, then it is camel case;
  # otherwise it is snake case.
  # Align the update_dict key case to target_dict to ensure correct dict update.
  aligned_update_dict = align_key_case(target_dict, update_dict)
  for key, value in aligned_update_dict.items():
    if (
        key in target_dict
        and isinstance(target_dict[key], dict)
        and isinstance(value, dict)
    ):
      recursive_dict_update(target_dict[key], value)
    elif key in target_dict and not isinstance(target_dict[key], type(value)):
      logger.warning(
          f"Type mismatch for key '{key}'. Existing type:"
          f' {type(target_dict[key])}, new type: {type(value)}. Overwriting.'
      )
      target_dict[key] = value
    else:
      target_dict[key] = value