# ajarbot/llm_interface.py
"""LLM Interface - Claude API, GLM, and other models.
Supports three modes for Claude:
1. Agent SDK (uses Pro subscription) - DEFAULT - Set USE_AGENT_SDK=true (default)
2. Direct API (pay-per-token) - Set USE_DIRECT_API=true
3. Legacy: Local Claude Code server - Set USE_CLAUDE_CODE_SERVER=true (deprecated)
"""
import os
from types import SimpleNamespace
from typing import Any, Dict, List, Optional

import requests
from anthropic import Anthropic
from anthropic.types import Message, TextBlock, ToolUseBlock, Usage

from usage_tracker import UsageTracker

# Try to import the Agent SDK (optional dependency)
try:
    from claude_agent_sdk import AgentSDK
    import anyio
    AGENT_SDK_AVAILABLE = True
except ImportError:
    AGENT_SDK_AVAILABLE = False

# API key environment variable names by provider
_API_KEY_ENV_VARS = {
    "claude": "ANTHROPIC_API_KEY",
    "glm": "GLM_API_KEY",
}

# Mode selection (priority order: USE_DIRECT_API > USE_CLAUDE_CODE_SERVER > default to Agent SDK)
_USE_DIRECT_API = os.getenv("USE_DIRECT_API", "false").lower() == "true"
_CLAUDE_CODE_SERVER_URL = os.getenv("CLAUDE_CODE_SERVER_URL", "http://localhost:8000")
_USE_CLAUDE_CODE_SERVER = os.getenv("USE_CLAUDE_CODE_SERVER", "false").lower() == "true"
# Agent SDK is the default if available and no other mode is explicitly enabled
_USE_AGENT_SDK = os.getenv("USE_AGENT_SDK", "true").lower() == "true"

# Default models by provider
_DEFAULT_MODELS = {
    "claude": "claude-haiku-4-5-20251001",  # For Direct API (pay-per-token)
    "claude_agent_sdk": "claude-sonnet-4-5-20250929",  # For Agent SDK (flat-rate subscription)
    "glm": "glm-4-plus",
}

_GLM_BASE_URL = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
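
# Example configuration (shell); a sketch of the env vars above with
# illustrative values, not a prescribed setup:
#   export USE_DIRECT_API=true            # force pay-per-token mode
#   export ANTHROPIC_API_KEY=...          # required for direct API mode
#   export USE_CLAUDE_CODE_SERVER=true    # deprecated legacy server mode
#   export CLAUDE_CODE_SERVER_URL=http://localhost:8000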


class LLMInterface:
    """Simple LLM interface supporting Claude and GLM."""

    def __init__(
        self,
        provider: str = "claude",
        api_key: Optional[str] = None,
        track_usage: bool = True,
    ) -> None:
        self.provider = provider
        self.api_key = api_key or os.getenv(
            _API_KEY_ENV_VARS.get(provider, ""),
        )
        self.client: Optional[Anthropic] = None
        self.agent_sdk: Optional[Any] = None

        # Determine mode (priority: direct API > legacy server > Agent SDK).
        # The model is set afterwards, once the mode is known.
        if provider == "claude":
            if _USE_DIRECT_API:
                self.mode = "direct_api"
            elif _USE_CLAUDE_CODE_SERVER:
                self.mode = "legacy_server"
            elif _USE_AGENT_SDK and AGENT_SDK_AVAILABLE:
                self.mode = "agent_sdk"
            else:
                # Fall back to the direct API if the Agent SDK is not available
                self.mode = "direct_api"
                if _USE_AGENT_SDK and not AGENT_SDK_AVAILABLE:
                    print("[LLM] Warning: Agent SDK not available, falling back to Direct API")
                    print("[LLM] Install with: pip install claude-agent-sdk")
        else:
            self.mode = "direct_api"  # Non-Claude providers use the direct API

        # Usage tracking (disabled when using the Agent SDK or legacy server)
        self.tracker = UsageTracker() if (track_usage and self.mode == "direct_api") else None

        # Set model based on mode
        if provider == "claude":
            if self.mode == "agent_sdk":
                self.model = _DEFAULT_MODELS.get("claude_agent_sdk", "claude-sonnet-4-5-20250929")
            else:
                self.model = _DEFAULT_MODELS.get(provider, "claude-haiku-4-5-20251001")
        else:
            self.model = _DEFAULT_MODELS.get(provider, "")

        # Initialize the client for the selected mode
        if provider == "claude":
            if self.mode == "agent_sdk":
                print(f"[LLM] Using Claude Agent SDK (flat-rate subscription) with model: {self.model}")
                self.agent_sdk = AgentSDK()
            elif self.mode == "direct_api":
                print(f"[LLM] Using Direct API (pay-per-token) with model: {self.model}")
                self.client = Anthropic(api_key=self.api_key)
            elif self.mode == "legacy_server":
                print(f"[LLM] Using Claude Code server at {_CLAUDE_CODE_SERVER_URL} (Pro subscription) with model: {self.model}")
                # Verify the server is running
                try:
                    response = requests.get(f"{_CLAUDE_CODE_SERVER_URL}/", timeout=2)
                    response.raise_for_status()
                    print(f"[LLM] Claude Code server is running: {response.json()}")
                except Exception as e:
                    print(f"[LLM] Warning: Could not connect to Claude Code server: {e}")
                    print("[LLM] Note: Claude Code server mode is deprecated; consider the Agent SDK instead.")

    def chat(
        self,
        messages: List[Dict],
        system: Optional[str] = None,
        max_tokens: int = 4096,
    ) -> str:
        """Send a chat request and return the response text.

        Raises:
            Exception: If the API call fails or returns an unexpected response.
        """
        if self.provider == "claude":
            # Agent SDK mode (Pro subscription)
            if self.mode == "agent_sdk":
                try:
                    # Bridge the async SDK to this sync interface
                    # (anyio.run starts an event loop for the coroutine)
                    response = anyio.run(
                        self._agent_sdk_chat,
                        messages,
                        system,
                        max_tokens,
                    )
                    return response
                except Exception as e:
                    raise Exception(f"Agent SDK error: {e}") from e
            # Legacy Claude Code server (Pro subscription)
            elif self.mode == "legacy_server":
                try:
                    payload = {
                        "messages": [{"role": m["role"], "content": m["content"]} for m in messages],
                        "system": system,
                        "max_tokens": max_tokens,
                    }
                    response = requests.post(
                        f"{_CLAUDE_CODE_SERVER_URL}/v1/chat",
                        json=payload,
                        timeout=120,
                    )
                    response.raise_for_status()
                    data = response.json()
                    return data.get("content", "")
                except Exception as e:
                    raise Exception(f"Claude Code server error: {e}") from e
            # Direct API (pay-per-token)
            elif self.mode == "direct_api":
                response = self.client.messages.create(
                    model=self.model,
                    max_tokens=max_tokens,
                    system=system or "",
                    messages=messages,
                )
                # Track usage
                if self.tracker and hasattr(response, "usage"):
                    self.tracker.track(
                        model=self.model,
                        input_tokens=response.usage.input_tokens,
                        output_tokens=response.usage.output_tokens,
                        cache_creation_tokens=getattr(
                            response.usage, "cache_creation_input_tokens", 0
                        ),
                        cache_read_tokens=getattr(
                            response.usage, "cache_read_input_tokens", 0
                        ),
                    )
                if not response.content:
                    return ""
                return response.content[0].text

        if self.provider == "glm":
            payload = {
                "model": self.model,
                "messages": [
                    {"role": "system", "content": system or ""},
                ] + messages,
                "max_tokens": max_tokens,
            }
            headers = {"Authorization": f"Bearer {self.api_key}"}
            response = requests.post(
                _GLM_BASE_URL, json=payload, headers=headers,
                timeout=60,
            )
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]

        raise ValueError(f"Unsupported provider: {self.provider}")
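
    # Example call (a sketch; both providers accept the same role/content
    # message dicts, so one shape works for either):
    #   llm.chat(
    #       [{"role": "user", "content": "Summarize this repo."}],
    #       system="You are a concise assistant.",
    #   )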

    async def _agent_sdk_chat(
        self,
        messages: List[Dict],
        system: Optional[str],
        max_tokens: int,
    ) -> str:
        """Internal async method for Agent SDK chat (called via the anyio bridge)."""
        response = await self.agent_sdk.chat(
            messages=messages,
            system=system,
            max_tokens=max_tokens,
            model=self.model,
        )
        # Extract text from the response
        if isinstance(response, dict):
            return response.get("content", "")
        return str(response)

    async def _agent_sdk_chat_with_tools(
        self,
        messages: List[Dict],
        tools: List[Dict[str, Any]],
        system: Optional[str],
        max_tokens: int,
    ) -> Message:
        """Internal async method for Agent SDK chat with tools (called via the anyio bridge)."""
        response = await self.agent_sdk.chat(
            messages=messages,
            tools=tools,
            system=system,
            max_tokens=max_tokens,
            model=self.model,
        )
        # Convert the Agent SDK response to anthropic.types.Message format
        return self._convert_sdk_response_to_message(response)

    def _convert_sdk_response_to_message(self, sdk_response: Dict[str, Any]) -> Message:
        """Convert an Agent SDK response to anthropic.types.Message format.

        This ensures compatibility with agent.py's existing tool loop.
        """
        # Extract content blocks
        content_blocks = []
        raw_content = sdk_response.get("content", [])
        if isinstance(raw_content, str):
            # Simple text response
            content_blocks = [TextBlock(type="text", text=raw_content)]
        elif isinstance(raw_content, list):
            # List of content blocks
            for block in raw_content:
                if isinstance(block, dict):
                    if block.get("type") == "text":
                        content_blocks.append(TextBlock(
                            type="text",
                            text=block.get("text", ""),
                        ))
                    elif block.get("type") == "tool_use":
                        content_blocks.append(ToolUseBlock(
                            type="tool_use",
                            id=block.get("id", ""),
                            name=block.get("name", ""),
                            input=block.get("input", {}),
                        ))

        # Extract usage information
        usage_data = sdk_response.get("usage", {})
        usage = Usage(
            input_tokens=usage_data.get("input_tokens", 0),
            output_tokens=usage_data.get("output_tokens", 0),
        )

        # Create a minimal Message-compatible object. The Message class from
        # anthropic.types is read-only, so we build a mock instead.
        # Capture self.model before defining the inner class.
        model_name = sdk_response.get("model", self.model)

        class MessageLike:
            def __init__(self, content, stop_reason, usage, model):
                self.content = content
                self.stop_reason = stop_reason
                self.usage = usage
                self.id = sdk_response.get("id", "sdk_message")
                self.model = model
                self.role = "assistant"
                self.type = "message"

        return MessageLike(
            content=content_blocks,
            stop_reason=sdk_response.get("stop_reason", "end_turn"),
            usage=usage,
            model=model_name,
        )
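
    # The converter above assumes an SDK response shaped roughly like this
    # (a sketch; the field names mirror the Anthropic messages API):
    #   {
    #       "content": [{"type": "text", "text": "..."},
    #                   {"type": "tool_use", "id": "...", "name": "...", "input": {...}}],
    #       "stop_reason": "end_turn",
    #       "usage": {"input_tokens": 0, "output_tokens": 0},
    #       "model": "...",
    #       "id": "...",
    #   }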

    def chat_with_tools(
        self,
        messages: List[Dict],
        tools: List[Dict[str, Any]],
        system: Optional[str] = None,
        max_tokens: int = 4096,
        use_cache: bool = False,
    ) -> Message:
        """Send a chat request with tool support. Returns the full Message object.

        Args:
            use_cache: Enable prompt caching for Sonnet models (saves ~90% on
                repeated context).
        """
        if self.provider != "claude":
            raise ValueError("Tool use only supported for Claude provider")

        # Agent SDK mode (Pro subscription)
        if self.mode == "agent_sdk":
            try:
                # Bridge the async SDK to this sync interface
                response = anyio.run(
                    self._agent_sdk_chat_with_tools,
                    messages,
                    tools,
                    system,
                    max_tokens,
                )
                return response
            except Exception as e:
                raise Exception(f"Agent SDK error: {e}") from e
        # Legacy Claude Code server (Pro subscription)
        elif self.mode == "legacy_server":
            try:
                payload = {
                    "messages": messages,
                    "tools": tools,
                    "system": system,
                    "max_tokens": max_tokens,
                }
                response = requests.post(
                    f"{_CLAUDE_CODE_SERVER_URL}/v1/chat/tools",
                    json=payload,
                    timeout=120,
                )
                response.raise_for_status()
                data = response.json()

                # Convert the response to a minimal Message-like object
                usage_data = data.get("usage", {})
                return SimpleNamespace(
                    content=data.get("content", []),
                    stop_reason=data.get("stop_reason", "end_turn"),
                    usage=SimpleNamespace(
                        input_tokens=usage_data.get("input_tokens", 0),
                        output_tokens=usage_data.get("output_tokens", 0),
                    ),
                )
            except Exception as e:
                raise Exception(f"Claude Code server error: {e}") from e
        # Direct API (pay-per-token)
        elif self.mode == "direct_api":
            # Enable caching only for Sonnet models (not worth it for Haiku)
            enable_caching = use_cache and "sonnet" in self.model.lower()

            # Structure the system prompt for optimal caching
            if enable_caching and system:
                # Convert the string to list format with cache control
                system_blocks = [
                    {
                        "type": "text",
                        "text": system,
                        "cache_control": {"type": "ephemeral"},
                    }
                ]
            else:
                system_blocks = system or ""

            response = self.client.messages.create(
                model=self.model,
                max_tokens=max_tokens,
                system=system_blocks,
                messages=messages,
                tools=tools,
            )
            # Track usage
            if self.tracker and hasattr(response, "usage"):
                self.tracker.track(
                    model=self.model,
                    input_tokens=response.usage.input_tokens,
                    output_tokens=response.usage.output_tokens,
                    cache_creation_tokens=getattr(
                        response.usage, "cache_creation_input_tokens", 0
                    ),
                    cache_read_tokens=getattr(
                        response.usage, "cache_read_input_tokens", 0
                    ),
                )
            return response
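
    # Example tool definition (a sketch; the schema follows the Anthropic
    # tools format, and the tool name and fields here are hypothetical):
    #   tools = [{
    #       "name": "get_weather",
    #       "description": "Get the current weather for a city.",
    #       "input_schema": {
    #           "type": "object",
    #           "properties": {"city": {"type": "string"}},
    #           "required": ["city"],
    #       },
    #   }]
    #   msg = llm.chat_with_tools(messages, tools, use_cache=True)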

    def set_model(self, model: str) -> None:
        """Change the active model."""
        self.model = model

    def get_usage_stats(self, target_date: Optional[str] = None) -> Dict:
        """Get usage statistics and costs.

        Args:
            target_date: Date string (YYYY-MM-DD). If None, returns today's stats.

        Returns:
            Dict with cost, token counts, and a breakdown by model.
        """
        if not self.tracker:
            return {"error": "Usage tracking not enabled"}
        return self.tracker.get_daily_cost(target_date)
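

# A minimal smoke-test sketch (illustrative; assumes the relevant API key
# env var is set and, for Agent SDK mode, that claude-agent-sdk is installed):
if __name__ == "__main__":
    llm = LLMInterface(provider="claude")
    print(f"mode={llm.mode} model={llm.model}")
    reply = llm.chat(
        [{"role": "user", "content": "Say hello in one short sentence."}],
        system="Be brief.",
    )
    print(reply)
    if llm.tracker:
        print(llm.get_usage_stats())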