312 lines
9.9 KiB
Python
312 lines
9.9 KiB
Python
|
|
"""Test script for Agent SDK implementation.
|
||
|
|
|
||
|
|
This script tests the Agent SDK integration without running the full bot.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import sys
|
||
|
|
|
||
|
|
# Force Agent SDK mode for this test run: enable the SDK flag and explicitly
# disable both alternative backends so mode selection cannot fall through.
os.environ.update(
    {
        "USE_AGENT_SDK": "true",
        "USE_DIRECT_API": "false",
        "USE_CLAUDE_CODE_SERVER": "false",
    }
)
|
||
|
|
|
||
|
|
def test_llm_interface_initialization():
    """Test 1: LLMInterface initialization with Agent SDK.

    Returns:
        bool: True when the interface is constructed without raising (a wrong
        mode only prints a warning and still counts as a pass); False when
        construction or the import raises.
    """
    print("\n=== Test 1: LLMInterface Initialization ===")
    try:
        # Imported lazily so an import failure is reported as a test failure
        # instead of crashing the whole script at load time.
        from llm_interface import LLMInterface

        llm = LLMInterface(provider="claude")

        # Plain string literal: the f-prefix was extraneous (no placeholders).
        print("✓ LLMInterface created successfully")
        print(f" - Provider: {llm.provider}")
        print(f" - Mode: {llm.mode}")
        print(f" - Model: {llm.model}")
        print(f" - Agent SDK available: {llm.agent_sdk is not None}")

        if llm.mode != "agent_sdk":
            # Warn but do not fail: the environment may simply lack the SDK.
            print(f"✗ WARNING: Expected mode 'agent_sdk', got '{llm.mode}'")
            if llm.mode == "direct_api":
                print(" - This likely means claude-agent-sdk is not installed")
                print(" - Run: pip install claude-agent-sdk")

        return True
    except Exception as e:
        print(f"✗ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
|
||
|
|
|
||
|
|
|
||
|
|
def test_simple_chat():
    """Test 2: Simple chat without tools.

    Returns:
        bool: True when the chat round-trip completes; False when the test is
        skipped (wrong mode) or any step raises.
    """
    print("\n=== Test 2: Simple Chat (No Tools) ===")
    try:
        from llm_interface import LLMInterface

        llm = LLMInterface(provider="claude")

        # Only meaningful in agent_sdk mode; otherwise skip (treated as fail).
        if llm.mode != "agent_sdk":
            print(f"⊘ Skipping test (mode is '{llm.mode}', not 'agent_sdk')")
            return False

        print("Sending simple chat message...")
        messages = [
            {"role": "user", "content": "Say 'Hello from Agent SDK!' in exactly those words."}
        ]

        response = llm.chat(messages, system="You are a helpful assistant.", max_tokens=100)

        # Plain string literal: the f-prefix was extraneous (no placeholders).
        print("✓ Chat completed successfully")
        print(f" - Response: {response[:100]}...")
        print(f" - Response type: {type(response)}")

        return True
    except Exception as e:
        print(f"✗ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
|
||
|
|
|
||
|
|
|
||
|
|
def test_chat_with_tools():
    """Test 3: Chat with tools (message format compatibility).

    Returns:
        bool: True when the tool-enabled chat completes and the response has
        the expected Anthropic-style shape; False on skip or error.
    """
    print("\n=== Test 3: Chat with Tools ===")
    try:
        from llm_interface import LLMInterface
        from tools import TOOL_DEFINITIONS

        llm = LLMInterface(provider="claude")

        # Only meaningful in agent_sdk mode; otherwise skip (treated as fail).
        if llm.mode != "agent_sdk":
            print(f"⊘ Skipping test (mode is '{llm.mode}', not 'agent_sdk')")
            return False

        print("Sending chat message with tool definitions...")
        messages = [
            {"role": "user", "content": "What is 2+2? Just respond with the number, don't use any tools."}
        ]

        response = llm.chat_with_tools(
            messages,
            tools=TOOL_DEFINITIONS,
            system="You are a helpful assistant.",
            max_tokens=100
        )

        # Plain string literal: the f-prefix was extraneous (no placeholders).
        print("✓ Chat with tools completed successfully")
        print(f" - Response type: {type(response)}")
        print(f" - Has .content: {hasattr(response, 'content')}")
        print(f" - Has .stop_reason: {hasattr(response, 'stop_reason')}")
        print(f" - Has .usage: {hasattr(response, 'usage')}")
        print(f" - Stop reason: {response.stop_reason}")

        # Dump each content block's type/text so format mismatches are visible.
        if hasattr(response, 'content') and response.content:
            print(f" - Content blocks: {len(response.content)}")
            for i, block in enumerate(response.content):
                print(f" - Block {i}: {type(block).__name__}")
                if hasattr(block, 'type'):
                    print(f" - Type: {block.type}")
                if hasattr(block, 'text'):
                    print(f" - Text: {block.text[:50]}...")

        return True
    except Exception as e:
        print(f"✗ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
|
||
|
|
|
||
|
|
|
||
|
|
def test_response_format_compatibility():
    """Test 4: Verify response format matches what agent.py expects.

    Feeds a hand-built SDK-style response dict through the interface's
    converter and checks the result exposes every attribute the agent relies
    on.

    Returns:
        bool: True when conversion succeeds and all attributes are present.
    """
    print("\n=== Test 4: Response Format Compatibility ===")
    try:
        from llm_interface import LLMInterface
        # NOTE(review): TextBlock/ToolUseBlock appear unused below — the
        # import presumably doubles as an anthropic-availability check;
        # confirm before removing.
        from anthropic.types import TextBlock, ToolUseBlock

        llm = LLMInterface(provider="claude")

        # Only meaningful in agent_sdk mode; otherwise skip (treated as fail).
        if llm.mode != "agent_sdk":
            print(f"⊘ Skipping test (mode is '{llm.mode}', not 'agent_sdk')")
            return False

        # Simulate SDK response
        mock_sdk_response = {
            "content": [
                {"type": "text", "text": "Test response"}
            ],
            "stop_reason": "end_turn",
            "usage": {
                "input_tokens": 10,
                "output_tokens": 5
            },
            "id": "test_message_id",
            "model": "claude-haiku-4-5-20251001"
        }

        print("Converting mock SDK response to Message format...")
        message = llm._convert_sdk_response_to_message(mock_sdk_response)

        # Plain string literal: the f-prefix was extraneous (no placeholders).
        print("✓ Conversion successful")
        print(f" - Message type: {type(message).__name__}")
        print(f" - Has content: {hasattr(message, 'content')}")
        print(f" - Has stop_reason: {hasattr(message, 'stop_reason')}")
        print(f" - Has usage: {hasattr(message, 'usage')}")
        print(f" - Content[0] type: {type(message.content[0]).__name__}")
        print(f" - Content[0].type: {message.content[0].type}")
        print(f" - Content[0].text: {message.content[0].text}")
        print(f" - Stop reason: {message.stop_reason}")
        print(f" - Usage.input_tokens: {message.usage.input_tokens}")
        print(f" - Usage.output_tokens: {message.usage.output_tokens}")

        # Verify all required attributes exist
        required_attrs = ['content', 'stop_reason', 'usage', 'id', 'model', 'role', 'type']
        missing_attrs = [attr for attr in required_attrs if not hasattr(message, attr)]

        if missing_attrs:
            print(f"✗ Missing attributes: {missing_attrs}")
            return False

        print("✓ All required attributes present")
        return True

    except Exception as e:
        print(f"✗ Test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
|
||
|
|
|
||
|
|
|
||
|
|
def test_mode_selection():
    """Test 5: Verify mode selection logic.

    Runs a table of environment-variable combinations, reimporting
    llm_interface for each so the module-level env reads are re-evaluated,
    and checks the resulting mode. The original environment is restored
    after every case.

    Returns:
        bool: True when every case selects the expected mode.
    """
    print("\n=== Test 5: Mode Selection Logic ===")

    test_cases = [
        {
            "name": "Default (Agent SDK)",
            "env": {},
            "expected": "agent_sdk"
        },
        {
            "name": "Explicit Direct API",
            "env": {"USE_DIRECT_API": "true"},
            "expected": "direct_api"
        },
        {
            "name": "Legacy Server",
            "env": {"USE_CLAUDE_CODE_SERVER": "true"},
            "expected": "legacy_server"
        },
        {
            "name": "Priority: Direct API > Agent SDK",
            "env": {"USE_DIRECT_API": "true", "USE_AGENT_SDK": "true"},
            "expected": "direct_api"
        },
        {
            "name": "Priority: Legacy > Agent SDK",
            "env": {"USE_CLAUDE_CODE_SERVER": "true", "USE_AGENT_SDK": "true"},
            "expected": "legacy_server"
        }
    ]

    all_passed = True

    for test_case in test_cases:
        print(f"\n Testing: {test_case['name']}")

        # Save current env (value is None when the variable is unset).
        old_env = {
            key: os.environ.get(key)
            for key in ("USE_DIRECT_API", "USE_CLAUDE_CODE_SERVER", "USE_AGENT_SDK")
        }

        # Set test env: clear all three flags, then apply this case's values.
        for key in old_env:
            os.environ.pop(key, None)
        for key, value in test_case["env"].items():
            os.environ[key] = value

        # Force reimport to pick up new env vars
        sys.modules.pop('llm_interface', None)

        try:
            from llm_interface import LLMInterface
            llm = LLMInterface(provider="claude")

            if llm.mode == test_case["expected"]:
                print(f" ✓ Correct mode: {llm.mode}")
            else:
                print(f" ✗ Wrong mode: expected '{test_case['expected']}', got '{llm.mode}'")
                all_passed = False

        except Exception as e:
            print(f" ✗ Error: {e}")
            all_passed = False

        # Restore env: drop whatever the case set, then put back saved values.
        for key, value in old_env.items():
            os.environ.pop(key, None)
            if value is not None:
                os.environ[key] = value

    # Force reimport one more time to reset
    sys.modules.pop('llm_interface', None)

    return all_passed
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Run every test, print a pass/fail summary, and return an exit code.

    Returns:
        int: 0 when all tests passed, 1 when any test failed or crashed.
    """
    banner = "=" * 70
    print(banner)
    print("AGENT SDK IMPLEMENTATION TEST SUITE")
    print(banner)

    tests = [
        ("Initialization", test_llm_interface_initialization),
        ("Simple Chat", test_simple_chat),
        ("Chat with Tools", test_chat_with_tools),
        ("Response Format", test_response_format_compatibility),
        ("Mode Selection", test_mode_selection),
    ]

    results = {}
    for name, test_func in tests:
        # A crash in one test is recorded as a failure and must not stop
        # the remaining tests from running.
        try:
            results[name] = test_func()
        except Exception as e:
            print(f"\n✗ Test '{name}' crashed: {e}")
            import traceback
            traceback.print_exc()
            results[name] = False

    # Summary
    print("\n" + banner)
    print("TEST SUMMARY")
    print(banner)

    for name, passed in results.items():
        status = "✓ PASS" if passed else "✗ FAIL"
        print(f"{status:8} {name}")

    passed_count = sum(1 for p in results.values() if p)
    total_count = len(results)

    print(f"\nTotal: {passed_count}/{total_count} tests passed")

    if passed_count == total_count:
        print("\n🎉 All tests passed!")
        return 0

    print(f"\n⚠ {total_count - passed_count} test(s) failed")
    return 1
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
sys.exit(main())
|