"""Base agent with LLM calling via litellm."""

from __future__ import annotations

import json
import logging
from typing import Any

import litellm

from app.config import settings

logger = logging.getLogger(__name__)

# Disable litellm telemetry
litellm.telemetry = False

# Markdown code-fence marker that models sometimes wrap their JSON output in.
_FENCE = "```"

# strict=False makes the decoder accept raw control characters (literal
# newlines, tabs, ...) inside JSON string values, which models sometimes emit.
_LENIENT_DECODER = json.JSONDecoder(strict=False)


def _strip_code_fence(raw: str) -> str:
    """Return *raw* without surrounding whitespace or a Markdown code fence.

    The opening fence line (for example ```` ```json ````) is dropped only
    when it is followed by a newline; a trailing fence is always dropped.
    """
    text = raw.strip()
    if text.startswith(_FENCE):
        first_newline = text.find("\n")
        if first_newline != -1:
            text = text[first_newline + 1:]
    if text.endswith(_FENCE):
        text = text[: -len(_FENCE)]
    return text.strip()


def _parse_json_lenient(text: str) -> Any:
    """Parse *text* as JSON, tolerating common LLM output quirks.

    Strategies, in order:

    1. Decode the whole text, allowing raw control characters in strings.
    2. Decode the first JSON value that starts at the first ``{`` and ignore
       anything after it (the model may have appended commentary).
    3. Hand the text to the optional ``json_repair`` package.

    Raises:
        json.JSONDecodeError: If the text contains no ``{`` at all, or if
            every strategy fails (including ``json_repair`` being absent).
    """
    try:
        return _LENIENT_DECODER.decode(text)
    except json.JSONDecodeError:
        pass

    start = text.find("{")
    if start == -1:
        raise json.JSONDecodeError("No JSON object found", text, 0)

    try:
        # raw_decode understands string context, so braces inside string
        # values cannot confuse it, and it stops at the end of the object.
        obj, _end = _LENIENT_DECODER.raw_decode(text, start)
        return obj
    except json.JSONDecodeError:
        pass

    try:
        import json_repair  # optional dependency, only needed as last resort

        return json_repair.loads(text)
    except Exception as exc:
        # Any failure here (missing package included) is reported to callers
        # as a JSON decoding problem, with the real cause chained.
        raise json.JSONDecodeError(
            "Failed to parse JSON after multiple attempts", text, 0
        ) from exc


class BaseAgent:
    """Base class for all pipeline agents."""

    # Identifier used as a prefix in log messages.
    name: str = "base"
    description: str = ""
    # Default system message; used when ``call_llm`` gets no ``system``.
    system_prompt: str = ""
    model: str = ""  # empty = use default from config

    def __init__(self, model: str | None = None):
        """Create the agent, optionally overriding the class-level model."""
        if model:
            self.model = model

    def get_model(self) -> str:
        """Return the agent's model, or ``settings.llm_model`` if unset."""
        return self.model or settings.llm_model

    async def call_llm(
        self,
        prompt: str,
        *,
        system: str | None = None,
        temperature: float = 0.3,
        max_tokens: int = 4096,
        response_format: dict | None = None,
    ) -> str:
        """Call LLM via litellm. Returns the text response.

        Args:
            prompt: Content of the user message.
            system: System message; when empty or ``None`` the agent's
                ``system_prompt`` is used, and if that is empty too no
                system message is sent.
            temperature: Sampling temperature forwarded to litellm.
            max_tokens: Completion token limit forwarded to litellm.
            response_format: Forwarded to litellm only when truthy.

        Returns:
            The content of the first choice, or ``""`` when the response
            carries no text content.
        """
        messages = []
        sys_prompt = system or self.system_prompt
        if sys_prompt:
            messages.append({"role": "system", "content": sys_prompt})
        messages.append({"role": "user", "content": prompt})

        model = self.get_model()
        kwargs: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        # Credentials and endpoint are passed only when configured.
        if settings.llm_api_key:
            kwargs["api_key"] = settings.llm_api_key
        if settings.llm_api_base:
            kwargs["api_base"] = settings.llm_api_base
        if response_format:
            kwargs["response_format"] = response_format

        logger.info("[%s] calling %s", self.name, model)
        response = await litellm.acompletion(**kwargs)
        content = response.choices[0].message.content
        if content is None:
            # Keep the ``-> str`` contract instead of failing on len(None).
            logger.warning("[%s] response contained no text content", self.name)
            content = ""
        logger.info("[%s] got %d chars", self.name, len(content))
        return content

    async def call_llm_json(self, prompt: str, **kwargs: Any) -> dict:
        """Call LLM and parse response as JSON.

        Args:
            prompt: Content of the user message.
            **kwargs: Forwarded to ``call_llm``. ``response_format`` defaults
                to ``{"type": "json_object"}`` but may be overridden.

        Returns:
            The decoded JSON value.

        Raises:
            json.JSONDecodeError: If no parsing strategy succeeds.
        """
        kwargs.setdefault("response_format", {"type": "json_object"})
        raw = await self.call_llm(prompt, **kwargs)
        return _parse_json_lenient(_strip_code_fence(raw))

    async def run(self, context: dict[str, Any]) -> dict[str, Any]:
        """Execute this agent's task. Override in subclasses.

        Args:
            context: Shared pipeline context (accumulated by previous agents).

        Returns:
            Dict of new keys to merge into context.
        """
        raise NotImplementedError