try:
import anthropic as _ # noqa: F401
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"Anthropic support requires the optional dependency 'anthropic'. "
'Install with `pip install "kosong[contrib]"`.'
) from exc
import copy
import json
from collections.abc import AsyncIterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, Self, TypedDict, Unpack, cast
from anthropic import (
AnthropicError,
AsyncAnthropic,
AsyncStream,
omit,
)
from anthropic import (
APIConnectionError as AnthropicAPIConnectionError,
)
from anthropic import (
APIStatusError as AnthropicAPIStatusError,
)
from anthropic import (
APITimeoutError as AnthropicAPITimeoutError,
)
from anthropic import (
AuthenticationError as AnthropicAuthenticationError,
)
from anthropic import (
PermissionDeniedError as AnthropicPermissionDeniedError,
)
from anthropic import (
RateLimitError as AnthropicRateLimitError,
)
from anthropic.lib.streaming import MessageStopEvent
from anthropic.types import (
Base64ImageSourceParam,
CacheControlEphemeralParam,
ContentBlockParam,
ImageBlockParam,
MessageDeltaEvent,
MessageDeltaUsage,
MessageParam,
MessageStartEvent,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawMessageStreamEvent,
TextBlockParam,
ThinkingBlockParam,
ThinkingConfigParam,
ToolChoiceParam,
ToolParam,
ToolResultBlockParam,
ToolUseBlockParam,
URLImageSourceParam,
Usage,
)
from anthropic.types import (
Message as AnthropicMessage,
)
from anthropic.types.tool_result_block_param import Content as ToolResultContent
from kosong.chat_provider import (
APIConnectionError,
APIStatusError,
APITimeoutError,
ChatProvider,
ChatProviderError,
StreamedMessagePart,
ThinkingEffort,
TokenUsage,
)
from kosong.contrib.chat_provider.common import ToolMessageConversion
from kosong.message import (
ContentPart,
ImageURLPart,
Message,
TextPart,
ThinkPart,
ToolCall,
ToolCallPart,
)
from kosong.tooling import Tool
if TYPE_CHECKING:

    def type_check(anthropic: "Anthropic"):
        # Static-only conformance check: assigning an `Anthropic` instance to a
        # `ChatProvider`-annotated name makes the type checker verify that the
        # class satisfies that interface. Never executed at runtime.
        _: ChatProvider = anthropic


# Pair of an optional string and a list of Anthropic message params.
# NOTE(review): presumably (system prompt, messages); not referenced in the
# visible part of this file -- confirm against its users.
type MessagePayload = tuple[str | None, list[MessageParam]]
# Beta feature identifiers accepted by `Anthropic.GenerationKwargs.beta_features`.
type BetaFeatures = Literal["interleaved-thinking-2025-05-14"]
class Anthropic:
"""
Chat provider backed by Anthropic's Messages API.
"""
name = "anthropic"
class GenerationKwargs(TypedDict, total=False):
max_tokens: int | None
temperature: float | None
top_k: int | None
top_p: float | None
# e.g., {"type": "enabled", "budget_tokens": 1024}
thinking: ThinkingConfigParam | None
# e.g., {"type": "auto", "disable_parallel_tool_use": True}
tool_choice: ToolChoiceParam | None
beta_features: list[BetaFeatures] | None
extra_headers: Mapping[str, str] | None
def __init__(
self,
*,
model: str,
api_key: str | None = None,
base_url: str | None = None,
stream: bool = True,
# which process should we apply on tool result
tool_message_conversion: ToolMessageConversion | None = None,
# Must provide a max_tokens. Can be overridden by .with_generation_kwargs()
default_max_tokens: int,
**client_kwargs: Any,
):
self._model = model
self._stream = stream
self._client = AsyncAnthropic(api_key=api_key, base_url=base_url, **client_kwargs)
self._tool_message_conversion: ToolMessageConversion | None = tool_message_conversion
self._generation_kwargs: Anthropic.GenerationKwargs = {
"max_tokens": default_max_tokens,
"beta_features": ["interleaved-thinking-2025-05-14"],
}
@property
def model_name(self) -> str:
return self._model
@property
def thinking_effort(self) -> "ThinkingEffort | None":
thinking_config = self._generation_kwargs.get("thinking")
if thinking_config is None:
return None
if thinking_config["type"] == "disabled":
return "off"
budget = thinking_config["budget_tokens"]
if budget <= 1024:
return "low"
if budget <= 4096:
return "medium"
return "high"
async def generate(
self,
system_prompt: str,
tools: Sequence[Tool],
history: Sequence[Message],
) -> "AnthropicStreamedMessage":
# https://docs.claude.com/en/api/messages#body-messages
# Anthropic API does not support system roles, but just a system prompt.
system = (
[
TextBlockParam(
text=system_prompt,
type="text",
cache_control=CacheControlEphemeralParam(type="ephemeral"),
)
]
if system_prompt
else omit
)
messages: list[MessageParam] = []
for message in history:
messages.append(self._convert_message(message))
if messages:
last_message = messages[-1]
last_content = last_message["content"]
# inject cache control in the last content.
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching
if isinstance(last_content, list) and last_content:
content_blocks = cast(list[ContentBlockParam], last_content)
last_block = content_blocks[-1]
match last_block["type"]:
case (
"text"
| "image"
| "document"
| "search_result"
| "tool_use"
| "tool_result"
| "server_tool_use"
| "web_search_tool_result"
):
last_block["cache_control"] = CacheControlEphemeralParam(type="ephemeral")
case "thinking" | "redacted_thinking":
pass
generation_kwargs: dict[str, Any] = {}
generation_kwargs.update(self._generation_kwargs)
betas = generation_kwargs.pop("beta_features", [])
extra_headers = {
**{"anthropic-beta": ",".join(str(e) for e in betas)},
**(generation_kwargs.pop("extra_headers", {})),
}
tools_ = [_convert_tool(tool) for tool in tools]
if tools:
tools_[-1]["cache_control"] = CacheControlEphemeralParam(type="ephemeral")
try:
response = await self._client.messages.create(
model=self._model,
messages=messages,
system=system,
tools=tools_,
stream=self._stream,
extra_headers=extra_headers,
**generation_kwargs,
)
return AnthropicStreamedMessage(response)
except AnthropicError as e:
raise _convert_error(e) from e
def with_thinking(self, effort: "ThinkingEffort") -> Self:
# XXX: this is a heuristic mapping based on suggestions given by Claude
thinking_config: ThinkingConfigParam
match effort:
case "off":
thinking_config = {"type": "disabled"}
case "low":
thinking_config = {"type": "enabled", "budget_tokens": 1024}
case "medium":
thinking_config = {"type": "enabled", "budget_tokens": 4096}
case "high":
thinking_config = {"type": "enabled", "budget_tokens": 32_000}
return self.with_generation_kwargs(thinking=thinking_config)
def with_generation_kwargs(self, **kwargs: Unpack[GenerationKwargs]) -> Self:
"""
Copy the chat provider, updating the generation kwargs with the given values.
Returns:
Self: A new instance of the chat provider with updated generation kwargs.
"""
new_self = copy.copy(self)
new_self._generation_kwargs = copy.deepcopy(self._generation_kwargs)
new_self._generation_kwargs.update(kwargs)
return new_self
@property
def model_parameters(self) -> dict[str, Any]:
"""
The parameters of the model to use.
For tracing/logging purposes.
"""
model_parameters: dict[str, Any] = {"base_url": str(self._client.base_url)}
model_parameters.update(self._generation_kwargs)
return model_parameters
def _convert_message(self, message: Message) -> MessageParam:
"""Convert a single internal message into Anthropic wire format."""
role = message.role
if role == "system":
# Anthropic does not support system messages in the conversation.
# We map it to a special user message.
return MessageParam(
role="user",
content=[
TextBlockParam(
type="text", text=f"{message.extract_text(sep='\n')}"
)
],
)
elif role == "tool":
if message.tool_call_id is None:
raise ChatProviderError("Tool message missing `tool_call_id`.")
if self._tool_message_conversion == "extract_text":
content = message.extract_text(sep="\n")
else:
content = message.content
block = _tool_result_message_to_block(message.tool_call_id, content)
return MessageParam(role="user", content=[block])
assert role in ("user", "assistant")
blocks: list[ContentBlockParam] = []
for part in message.content:
if isinstance(part, TextPart):
blocks.append(TextBlockParam(type="text", text=part.text))
elif isinstance(part, ImageURLPart):
blocks.append(_image_url_part_to_anthropic(part))
elif isinstance(part, ThinkPart):
if part.encrypted is None:
# missing signature, strip this thinking block.
continue
else:
blocks.append(
ThinkingBlockParam(
type="thinking", thinking=part.think, signature=part.encrypted
)
)
else:
continue
for tool_call in message.tool_calls or []:
if tool_call.function.arguments:
try:
parsed_arguments = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as exc: # pragma: no cover - defensive guard
raise ChatProviderError("Tool call arguments must be valid JSON.") from exc
if not isinstance(parsed_arguments, dict):
raise ChatProviderError("Tool call arguments must be a JSON object.")
tool_input = cast(dict[str, object], parsed_arguments)
else:
tool_input = {}
blocks.append(
ToolUseBlockParam(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name,
input=tool_input,
)
)
return MessageParam(role=role, content=blocks)
class AnthropicStreamedMessage:
    """
    Async iterator over the parts of one Anthropic response.

    Accepts either a complete message or an event stream and yields the same
    kind of parts for both. `id` and `usage` are filled in while iterating.
    """

    def __init__(self, response: AnthropicMessage | AsyncStream[RawMessageStreamEvent]):
        self._id: str | None = None
        self._usage = Usage(input_tokens=0, output_tokens=0)
        # Both converters are async generators: nothing runs until the first
        # `__anext__()` call.
        self._iter = (
            self._convert_non_stream_response(response)
            if isinstance(response, AnthropicMessage)
            else self._convert_stream_response(response)
        )

    def __aiter__(self) -> AsyncIterator[StreamedMessagePart]:
        return self

    async def __anext__(self) -> StreamedMessagePart:
        return await self._iter.__anext__()

    @property
    def id(self) -> str | None:
        """Message id reported by the API, or `None` if not seen yet."""
        return self._id

    @property
    def usage(self) -> TokenUsage | None:
        """Token usage collected so far, with missing counters reported as 0."""
        # https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance
        current = self._usage
        return TokenUsage(
            # Note: in some Anthropic-compatible APIs, input_tokens can be None
            input_other=current.input_tokens or 0,
            output=current.output_tokens,
            input_cache_read=current.cache_read_input_tokens or 0,
            input_cache_creation=current.cache_creation_input_tokens or 0,
        )

    def _update_usage(self, delta_usage: MessageDeltaUsage) -> None:
        """Overwrite each stored usage counter that `delta_usage` provides."""
        for field_name in (
            "cache_creation_input_tokens",
            "cache_read_input_tokens",
            "input_tokens",
            "output_tokens",
        ):
            reported = getattr(delta_usage, field_name)
            if reported is not None:
                setattr(self._usage, field_name, reported)

    async def _convert_non_stream_response(
        self,
        response: AnthropicMessage,
    ) -> AsyncIterator[StreamedMessagePart]:
        """Yield one part per supported content block of a complete message."""
        self._id = response.id
        self._usage = response.usage
        for block in response.content:
            if block.type == "text":
                yield TextPart(text=block.text)
            elif block.type == "thinking":
                yield ThinkPart(think=block.thinking, encrypted=block.signature)
            elif block.type == "redacted_thinking":
                yield ThinkPart(think="", encrypted=block.data)
            elif block.type == "tool_use":
                arguments = json.dumps(block.input)
                yield ToolCall(
                    id=block.id,
                    function=ToolCall.FunctionBody(name=block.name, arguments=arguments),
                )
            # Every other block type is skipped.

    async def _convert_stream_response(
        self,
        manager: AsyncStream[RawMessageStreamEvent],
    ) -> AsyncIterator[StreamedMessagePart]:
        """
        Yield parts as stream events arrive, tracking id and usage on the way.

        Any `AnthropicError` raised while streaming is re-raised as the
        provider error produced by `_convert_error()`.
        """
        try:
            async with manager as stream:
                async for event in stream:
                    if isinstance(event, MessageStartEvent):
                        # The start event carries the message id and the
                        # initial (prompt/input) token usage.
                        started = event.message
                        self._id = started.id
                        self._usage = started.usage
                    elif isinstance(event, RawContentBlockStartEvent):
                        block = event.content_block
                        if block.type == "text":
                            yield TextPart(text=block.text)
                        elif block.type == "thinking":
                            yield ThinkPart(think=block.thinking)
                        elif block.type == "redacted_thinking":
                            yield ThinkPart(think="", encrypted=block.data)
                        elif block.type == "tool_use":
                            # Arguments follow as `input_json_delta` events.
                            yield ToolCall(
                                id=block.id,
                                function=ToolCall.FunctionBody(name=block.name, arguments=""),
                            )
                        # Other block types (e.g. server tool use, web search
                        # results) produce no part.
                    elif isinstance(event, RawContentBlockDeltaEvent):
                        delta = event.delta
                        if delta.type == "text_delta":
                            yield TextPart(text=delta.text)
                        elif delta.type == "thinking_delta":
                            yield ThinkPart(think=delta.thinking)
                        elif delta.type == "input_json_delta":
                            yield ToolCallPart(arguments_part=delta.partial_json)
                        elif delta.type == "signature_delta":
                            yield ThinkPart(think="", encrypted=delta.signature)
                        # Other delta types (e.g. citations) produce no part.
                    elif isinstance(event, MessageDeltaEvent):
                        if event.usage:
                            self._update_usage(event.usage)
                    elif isinstance(event, MessageStopEvent):
                        # Nothing to emit for the stop event.
                        continue
        except AnthropicError as exc:
            raise _convert_error(exc) from exc
def _convert_tool(tool: Tool) -> ToolParam:
    """Build the Anthropic tool definition from a tool's name, description and parameters."""
    return ToolParam(
        name=tool.name,
        description=tool.description,
        input_schema=tool.parameters,
    )
def _tool_result_message_to_block(
    tool_call_id: str, content: str | list[ContentPart]
) -> ToolResultBlockParam:
    """
    Wrap a tool result in a `tool_result` block for the given tool call id.

    A string is passed through unchanged. A list of parts is mapped to content
    blocks: non-empty text parts and image parts are kept, empty text parts are
    dropped.

    Raises:
        ChatProviderError: If a part is neither a text part nor an image part.
    """

    def _to_block(part: ContentPart) -> ToolResultContent:
        if isinstance(part, ImageURLPart):
            return _image_url_part_to_anthropic(part)
        # https://docs.claude.com/en/docs/build-with-claude/files#file-types-and-content-blocks
        # Anthropic API supports very limited file types
        raise ChatProviderError(f"Anthropic API does not support {type(part)} in tool result")

    block_content: str | list[ToolResultContent]
    if isinstance(content, str):
        # Already flattened to text by the caller.
        block_content = content
    else:
        converted: list[ToolResultContent] = []
        for part in content:
            if not isinstance(part, TextPart):
                converted.append(_to_block(part))
            elif part.text:
                converted.append(TextBlockParam(type="text", text=part.text))
        block_content = converted

    return ToolResultBlockParam(
        type="tool_result",
        tool_use_id=tool_call_id,
        content=block_content,
    )
def _image_url_part_to_anthropic(part: ImageURLPart) -> ImageBlockParam:
    """
    Convert an image URL part into an Anthropic image block.

    `data:` URLs become base64 image sources; any other URL is passed through
    as a URL image source.

    Raises:
        ChatProviderError: If a `data:` URL has no `;base64,` separator or its
            media type is not one of PNG, JPEG, GIF or WebP.
    """
    url = part.image_url.url
    if not url.startswith("data:"):
        return ImageBlockParam(
            type="image",
            source=URLImageSourceParam(type="url", url=url),
        )

    # Expected shape: data:<media type>;base64,<payload>
    media_type, separator, data = url[5:].partition(";base64,")
    if not separator:
        raise ChatProviderError(f"Invalid data URL for image: {url}")
    if media_type not in ("image/png", "image/jpeg", "image/gif", "image/webp"):
        raise ChatProviderError(
            f"Unsupported media type for base64 image: {media_type}, url: {url}"
        )
    source = Base64ImageSourceParam(
        type="base64",
        data=data,
        media_type=media_type,
    )
    return ImageBlockParam(type="image", source=source)
def _convert_error(error: AnthropicError) -> ChatProviderError:
    """
    Translate an Anthropic SDK exception into the matching provider error.

    Subclasses are tested before their base classes: in the Anthropic SDK the
    authentication, permission and rate-limit errors derive from the API status
    error, and the timeout error derives from the connection error. Testing the
    base class first would make the more specific branch unreachable, which in
    particular would report timeouts as plain connection errors.

    Args:
        error: The exception raised by the Anthropic SDK.

    Returns:
        An `APIStatusError`, `APITimeoutError` or `APIConnectionError` when the
        exception matches one of those categories, otherwise a generic
        `ChatProviderError` carrying the original message.
    """
    if isinstance(error, AnthropicAuthenticationError):
        return APIStatusError(getattr(error, "status_code", 401), str(error))
    if isinstance(error, AnthropicPermissionDeniedError):
        return APIStatusError(getattr(error, "status_code", 403), str(error))
    if isinstance(error, AnthropicRateLimitError):
        return APIStatusError(getattr(error, "status_code", 429), str(error))
    if isinstance(error, AnthropicAPIStatusError):
        return APIStatusError(error.status_code, str(error))
    # Must precede the connection check: a timeout is also a connection error.
    if isinstance(error, AnthropicAPITimeoutError):
        return APITimeoutError(str(error))
    if isinstance(error, AnthropicAPIConnectionError):
        return APIConnectionError(str(error))
    return ChatProviderError(f"Anthropic error: {error}")