Files
20260327-c863ce53/app/agents/base.py
2026-04-25 19:25:22 +08:00

167 lines
5.4 KiB
Python

"""Base agent with LLM calling via litellm."""
from __future__ import annotations
import json
import logging
from typing import Any
import litellm
from app.config import settings
# Module-wide logger, named after this module's dotted import path.
logger = logging.getLogger(name=__name__)
# Opt out of litellm telemetry; this is a module-global flag on the litellm
# package, so it applies to every caller in the process.
litellm.telemetry = False
class BaseAgent:
    """Base class for all pipeline agents.

    Subclasses set ``name``, ``description`` and ``system_prompt`` and
    override :meth:`run`. LLM access goes through :meth:`call_llm` (plain
    text) or :meth:`call_llm_json` (parsed JSON).
    """

    name: str = "base"
    description: str = ""
    system_prompt: str = ""
    model: str = ""  # empty = use default from config

    def __init__(self, model: str | None = None):
        """Optionally override the class-level ``model`` for this instance.

        Args:
            model: litellm model identifier. A falsy value keeps the class
                attribute, which in turn falls back to ``settings.llm_model``.
        """
        if model:
            self.model = model

    def get_model(self) -> str:
        """Return this agent's model, falling back to ``settings.llm_model``."""
        return self.model or settings.llm_model

    async def call_llm(
        self,
        prompt: str,
        *,
        system: str | None = None,
        temperature: float = 0.3,
        max_tokens: int = 4096,
        response_format: dict | None = None,
    ) -> str:
        """Call the LLM via ``litellm.acompletion`` and return the text reply.

        Args:
            prompt: User message content.
            system: System prompt; defaults to ``self.system_prompt``. No
                system message is sent when both are empty.
            temperature: Sampling temperature forwarded to litellm.
            max_tokens: Completion token limit forwarded to litellm.
            response_format: Optional response format forwarded to litellm,
                e.g. ``{"type": "json_object"}``.

        Returns:
            The first choice's message content, or ``""`` when the response
            carried no text content (a warning is logged in that case).
        """
        messages: list[dict[str, str]] = []
        sys_prompt = system or self.system_prompt
        if sys_prompt:
            messages.append({"role": "system", "content": sys_prompt})
        messages.append({"role": "user", "content": prompt})

        model = self.get_model()
        kwargs: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        # Only pass credential/endpoint overrides when they are configured.
        if settings.llm_api_key:
            kwargs["api_key"] = settings.llm_api_key
        if settings.llm_api_base:
            kwargs["api_base"] = settings.llm_api_base
        if response_format:
            kwargs["response_format"] = response_format

        logger.info("[%s] calling %s", self.name, model)
        response = await litellm.acompletion(**kwargs)
        choice = response.choices[0]
        content = choice.message.content
        if content is None:
            # message.content may be None (e.g. tool calls or filtered
            # output); normalise so callers always receive a str.
            logger.warning(
                "[%s] %s returned no text content (finish_reason=%s)",
                self.name,
                model,
                getattr(choice, "finish_reason", None),
            )
            content = ""
        logger.info("[%s] got %d chars", self.name, len(content))
        return content

    async def call_llm_json(self, prompt: str, **kwargs) -> dict:
        """Call the LLM in JSON mode and leniently parse the reply.

        ``response_format`` defaults to ``{"type": "json_object"}`` but may be
        supplied by the caller; every other keyword argument is forwarded to
        :meth:`call_llm` unchanged.

        Parsing strategies, in order:

        1. the reply with surrounding Markdown code fences removed;
        2. the same text with raw control characters inside strings escaped;
        3. the first JSON value starting at the first ``{`` (text before it
           and after its end is ignored; braces inside strings are handled);
        4. ``json_repair.loads`` when that optional package is installed.

        Returns:
            The decoded JSON value (normally a dict).

        Raises:
            json.JSONDecodeError: if no strategy produces a result.
        """
        # setdefault (instead of a hard-coded keyword) lets callers override
        # the format without a "multiple values for keyword" TypeError.
        kwargs.setdefault("response_format", {"type": "json_object"})
        raw = await self.call_llm(prompt, **kwargs)
        return self._parse_json_reply(raw)

    @staticmethod
    def _strip_code_fences(raw: str) -> str:
        """Trim whitespace and remove a leading ```lang line / trailing ``` fence."""
        text = raw.strip()
        if text.startswith("```"):
            first_nl = text.find("\n")
            if first_nl != -1:
                text = text[first_nl + 1:]
            if text.endswith("```"):
                text = text[: text.rfind("```")]
            text = text.strip()
        return text

    @staticmethod
    def _escape_control_chars(s: str) -> str:
        """Escape raw control characters (< U+0020) found inside JSON strings.

        Models sometimes emit literal newlines/tabs inside string values,
        which strict JSON rejects. Best-effort: string boundaries are tracked
        only by unescaped double quotes, starting outside a string.
        """
        named = {"\n": "\\n", "\r": "\\r", "\t": "\\t"}
        out: list[str] = []
        in_string = False
        escape = False
        for ch in s:
            if escape:
                escape = False
            elif ch == "\\":
                escape = True
            elif ch == '"':
                in_string = not in_string
            elif in_string and ord(ch) < 32:
                out.append(named.get(ch, f"\\u{ord(ch):04x}"))
                continue
            out.append(ch)
        return "".join(out)

    @classmethod
    def _parse_json_reply(cls, raw: str) -> Any:
        """Decode raw model output using the strategies listed in call_llm_json."""
        text = cls._strip_code_fences(raw)

        last_error: Exception | None = None
        for candidate in (text, cls._escape_control_chars(text)):
            try:
                return json.loads(candidate)
            except json.JSONDecodeError as exc:
                last_error = exc

        start = text.find("{")
        if start == -1:
            raise json.JSONDecodeError("No JSON object found", text, 0)

        # raw_decode parses one value and reports where it ended, so trailing
        # commentary is ignored and braces inside strings are handled
        # correctly. Escaping is applied to the fragment (not the whole text)
        # so string tracking starts at the object, unaffected by any quotes
        # in a preamble, and no index translation between texts is needed.
        fragment = text[start:]
        decoder = json.JSONDecoder()
        for candidate in (fragment, cls._escape_control_chars(fragment)):
            try:
                value, _end = decoder.raw_decode(candidate)
                return value
            except json.JSONDecodeError as exc:
                last_error = exc

        try:
            import json_repair  # optional dependency
        except ImportError:
            pass
        else:
            try:
                repaired = json_repair.loads(text)
            except Exception as exc:  # deliberate best-effort fallback
                last_error = exc
            else:
                # NOTE(review): json_repair may return an empty or scalar value
                # instead of raising when it cannot recover anything; only a
                # non-empty container counts as success.
                if isinstance(repaired, (dict, list)) and repaired:
                    return repaired

        raise json.JSONDecodeError(
            "Failed to parse JSON after multiple attempts", text, 0
        ) from last_error

    async def run(self, context: dict[str, Any]) -> dict[str, Any]:
        """Execute this agent's task. Override in subclasses.

        Args:
            context: Shared pipeline context (accumulated by previous agents).

        Returns:
            Dict of new keys to merge into context.
        """
        raise NotImplementedError