167 lines
5.4 KiB
Python
167 lines
5.4 KiB
Python
"""Base agent with LLM calling via litellm."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
import litellm
|
|
|
|
from app.config import settings
|
|
|
|
# Module-level logger; BaseAgent prefixes its messages with the agent's name.
logger = logging.getLogger(__name__)

# Disable litellm telemetry
litellm.telemetry = False
|
|
|
|
|
|
class BaseAgent:
|
|
"""Base class for all pipeline agents."""
|
|
|
|
name: str = "base"
|
|
description: str = ""
|
|
system_prompt: str = ""
|
|
model: str = "" # empty = use default from config
|
|
|
|
def __init__(self, model: str | None = None):
|
|
if model:
|
|
self.model = model
|
|
|
|
def get_model(self) -> str:
|
|
return self.model or settings.llm_model
|
|
|
|
async def call_llm(
|
|
self,
|
|
prompt: str,
|
|
*,
|
|
system: str | None = None,
|
|
temperature: float = 0.3,
|
|
max_tokens: int = 4096,
|
|
response_format: dict | None = None,
|
|
) -> str:
|
|
"""Call LLM via litellm. Returns the text response."""
|
|
messages = []
|
|
sys_prompt = system or self.system_prompt
|
|
if sys_prompt:
|
|
messages.append({"role": "system", "content": sys_prompt})
|
|
messages.append({"role": "user", "content": prompt})
|
|
|
|
kwargs: dict[str, Any] = {
|
|
"model": self.get_model(),
|
|
"messages": messages,
|
|
"temperature": temperature,
|
|
"max_tokens": max_tokens,
|
|
}
|
|
if settings.llm_api_key:
|
|
kwargs["api_key"] = settings.llm_api_key
|
|
if settings.llm_api_base:
|
|
kwargs["api_base"] = settings.llm_api_base
|
|
if response_format:
|
|
kwargs["response_format"] = response_format
|
|
|
|
logger.info(f"[{self.name}] calling {self.get_model()}")
|
|
response = await litellm.acompletion(**kwargs)
|
|
content = response.choices[0].message.content
|
|
logger.info(f"[{self.name}] got {len(content)} chars")
|
|
return content
|
|
|
|
async def call_llm_json(self, prompt: str, **kwargs) -> dict:
|
|
"""Call LLM and parse response as JSON."""
|
|
raw = await self.call_llm(
|
|
prompt,
|
|
response_format={"type": "json_object"},
|
|
**kwargs,
|
|
)
|
|
# Strip markdown code fences if present
|
|
text = raw.strip()
|
|
if text.startswith("```"):
|
|
first_nl = text.find("\n")
|
|
if first_nl != -1:
|
|
text = text[first_nl + 1:]
|
|
if text.endswith("```"):
|
|
text = text[: text.rfind("```")]
|
|
text = text.strip()
|
|
|
|
# Sanitize control characters inside JSON string values
|
|
# (models sometimes emit literal newlines/tabs inside strings)
|
|
import re
|
|
def _clean_json_string(s: str) -> str:
|
|
# Replace unescaped control chars within JSON strings
|
|
# This is a best-effort fix for common model outputs
|
|
result = []
|
|
in_string = False
|
|
escape = False
|
|
for ch in s:
|
|
if escape:
|
|
result.append(ch)
|
|
escape = False
|
|
continue
|
|
if ch == '\\':
|
|
result.append(ch)
|
|
escape = True
|
|
continue
|
|
if ch == '"':
|
|
in_string = not in_string
|
|
result.append(ch)
|
|
continue
|
|
if in_string and ord(ch) < 32:
|
|
# Replace control chars with escaped versions
|
|
if ch == '\n':
|
|
result.append('\\n')
|
|
elif ch == '\r':
|
|
result.append('\\r')
|
|
elif ch == '\t':
|
|
result.append('\\t')
|
|
else:
|
|
result.append(f'\\u{ord(ch):04x}')
|
|
continue
|
|
result.append(ch)
|
|
return ''.join(result)
|
|
|
|
# Try parsing with multiple strategies
|
|
for attempt, candidate in enumerate([text, _clean_json_string(text)]):
|
|
try:
|
|
return json.loads(candidate)
|
|
except json.JSONDecodeError:
|
|
continue
|
|
|
|
# Last resort: try to extract the largest valid JSON object
|
|
# (model may have appended commentary after the JSON)
|
|
brace_depth = 0
|
|
start = text.find('{')
|
|
if start == -1:
|
|
raise json.JSONDecodeError("No JSON object found", text, 0)
|
|
|
|
cleaned = _clean_json_string(text)
|
|
for i, ch in enumerate(cleaned[start:], start):
|
|
if ch == '{':
|
|
brace_depth += 1
|
|
elif ch == '}':
|
|
brace_depth -= 1
|
|
if brace_depth == 0:
|
|
try:
|
|
return json.loads(cleaned[start:i + 1])
|
|
except json.JSONDecodeError:
|
|
continue
|
|
|
|
# If all else fails, use json_repair library or raise
|
|
try:
|
|
import json_repair
|
|
return json_repair.loads(text)
|
|
except (ImportError, Exception):
|
|
raise json.JSONDecodeError(
|
|
f"Failed to parse JSON after multiple attempts", text, 0
|
|
)
|
|
|
|
async def run(self, context: dict[str, Any]) -> dict[str, Any]:
|
|
"""Execute this agent's task. Override in subclasses.
|
|
|
|
Args:
|
|
context: Shared pipeline context (accumulated by previous agents).
|
|
|
|
Returns:
|
|
Dict of new keys to merge into context.
|
|
"""
|
|
raise NotImplementedError
|