auto-save 2026-04-01 09:03 (+8, ~2)
This commit is contained in:
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
3
src/capture/__init__.py
Normal file
3
src/capture/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .adb_capture import ADBCapture
|
||||
|
||||
__all__ = ["ADBCapture"]
|
||||
118
src/capture/adb_capture.py
Normal file
118
src/capture/adb_capture.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""L1 - Screen Capture via ADB
|
||||
|
||||
Captures screenshots from Android device using ADB.
|
||||
Handles device connection, screenshot acquisition, and resolution detection.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from config import settings
|
||||
|
||||
|
||||
class ADBCapture:
    """ADB-based screen capture for Android devices.

    Shells out to the ``adb`` binary configured in ``settings.adb_path`` to
    list devices, read the screen resolution and grab screenshots.  When a
    serial is known, every command targets that device via ``adb -s``.
    """

    def __init__(self):
        """Load adb settings and make sure the screenshot directory exists."""
        self.adb = settings.adb_path
        self.serial = settings.device_serial
        self.screenshot_dir = Path(settings.screenshot_dir)
        self.screenshot_dir.mkdir(parents=True, exist_ok=True)
        # Cached (width, height); filled by the first get_resolution() call.
        self._resolution: tuple[int, int] | None = None

    def _adb_cmd(self, *args: str) -> list[str]:
        """Build an adb argv list, adding ``-s <serial>`` when a serial is set."""
        cmd = [self.adb]
        if self.serial:
            cmd.extend(["-s", self.serial])
        cmd.extend(args)
        return cmd

    def check_device(self) -> dict:
        """Check if a device is connected and return device info.

        If no serial is configured, the first ready device is adopted and
        stored in ``self.serial`` for later commands.

        Returns:
            ``{"connected": False, "error": ...}`` when no usable device is
            attached (including the case where the configured serial is not
            in the device list); otherwise a dict with ``connected``,
            ``serial``, ``model``, ``resolution`` (``"WxH"``) and
            ``all_devices``.

        Raises:
            subprocess.TimeoutExpired: If an adb call exceeds 5 seconds.
            RuntimeError: If the screen resolution cannot be determined.
        """
        result = subprocess.run(
            self._adb_cmd("devices"),
            capture_output=True, text=True, timeout=5
        )
        # Only "<serial>\tdevice" rows are ready devices.  The header line
        # has no tab, so it is filtered out by the same test.
        devices = []
        for line in result.stdout.splitlines():
            parts = line.strip().split("\t")
            if len(parts) == 2 and parts[1] == "device":
                devices.append(parts[0])

        if not devices:
            return {"connected": False, "error": "No device found"}

        # A configured serial that is not attached must not be reported as
        # connected: every later `adb -s <serial>` call would fail.
        if self.serial and self.serial not in devices:
            return {
                "connected": False,
                "error": f"Device {self.serial} not found",
                "all_devices": devices,
            }

        if not self.serial:
            self.serial = devices[0]
        serial = self.serial

        # Get device model
        model_result = subprocess.run(
            self._adb_cmd("shell", "getprop", "ro.product.model"),
            capture_output=True, text=True, timeout=5
        )
        model = model_result.stdout.strip()

        # Get screen resolution
        w, h = self.get_resolution()

        return {
            "connected": True,
            "serial": serial,
            "model": model,
            "resolution": f"{w}x{h}",
            "all_devices": devices,
        }

    def get_resolution(self) -> tuple[int, int]:
        """Get device screen resolution as ``(width, height)`` in pixels.

        The value is read once from ``adb shell wm size`` and cached.  When
        the output contains several ``WxH`` entries, the last one is used,
        which is the same entry the previous colon-split parsing selected.

        Raises:
            RuntimeError: If no ``WxH`` pair is present in the output.
            subprocess.TimeoutExpired: If adb does not answer within 5 s.
        """
        import re

        if self._resolution:
            return self._resolution

        result = subprocess.run(
            self._adb_cmd("shell", "wm", "size"),
            capture_output=True, text=True, timeout=5
        )
        # Output: "Physical size: 1080x2400"
        sizes = re.findall(r"(\d+)\s*x\s*(\d+)", result.stdout)
        if not sizes:
            detail = (result.stderr or result.stdout).strip()
            raise RuntimeError(f"Could not determine screen resolution: {detail!r}")

        w, h = sizes[-1]
        self._resolution = (int(w), int(h))
        return self._resolution

    def screenshot(self, save: bool = True) -> Image.Image:
        """Take a screenshot and return as PIL Image.

        Args:
            save: Whether to save the screenshot to disk for debugging.

        Returns:
            PIL Image of the current screen.

        Raises:
            RuntimeError: If adb exits non-zero or returns no image data.
            subprocess.TimeoutExpired: If the capture exceeds
                ``settings.screenshot_timeout``.
        """
        result = subprocess.run(
            self._adb_cmd("exec-out", "screencap", "-p"),
            capture_output=True, timeout=settings.screenshot_timeout
        )
        if result.returncode != 0 or not result.stdout:
            # stderr is raw bytes here; never let decoding mask the real error.
            detail = result.stderr.decode(errors="replace")
            raise RuntimeError(f"Screenshot failed: {detail}")

        img = Image.open(io.BytesIO(result.stdout))

        if save:
            ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            path = self.screenshot_dir / f"{ts}.png"
            img.save(path)

        return img

    def screenshot_base64(self) -> str:
        """Take screenshot and return as base64-encoded PNG string.

        The screenshot is also written to the screenshot directory, because
        it is taken with ``save=True``.
        """
        import base64

        img = self.screenshot(save=True)
        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
3
src/executor/__init__.py
Normal file
3
src/executor/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .adb_executor import ADBExecutor
|
||||
|
||||
__all__ = ["ADBExecutor"]
|
||||
109
src/executor/adb_executor.py
Normal file
109
src/executor/adb_executor.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""L5 - Action Execution via ADB
|
||||
|
||||
Translates structured actions into ADB commands and executes them on device.
|
||||
Coordinates are normalized (0-1), converted to device pixels at execution time.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from config import settings
|
||||
|
||||
|
||||
@dataclass
|
||||
class Action:
|
||||
"""A single GUI action to execute."""
|
||||
type: str # tap, swipe, type, long_press, back, home, scroll, wait
|
||||
x: float = 0.0 # normalized x (0-1)
|
||||
y: float = 0.0 # normalized y (0-1)
|
||||
text: str = "" # for type action
|
||||
x2: float = 0.0 # for swipe end
|
||||
y2: float = 0.0 # for swipe end
|
||||
duration: int = 300 # ms, for long_press and swipe
|
||||
|
||||
|
||||
class ADBExecutor:
|
||||
"""Execute actions on Android device via ADB."""
|
||||
|
||||
def __init__(self, capture):
|
||||
self.capture = capture
|
||||
self.adb = settings.adb_path
|
||||
self.serial = settings.device_serial
|
||||
|
||||
def _adb_cmd(self, *args: str) -> list[str]:
|
||||
cmd = [self.adb]
|
||||
if self.serial:
|
||||
cmd.extend(["-s", self.serial])
|
||||
cmd.extend(args)
|
||||
return cmd
|
||||
|
||||
def _run(self, *args: str):
|
||||
cmd = self._adb_cmd(*args)
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"ADB command failed: {' '.join(cmd)}\n{result.stderr}")
|
||||
return result.stdout
|
||||
|
||||
def _to_pixels(self, x: float, y: float) -> tuple[int, int]:
|
||||
"""Convert normalized (0-1) coordinates to device pixels."""
|
||||
w, h = self.capture.get_resolution()
|
||||
return int(x * w), int(y * h)
|
||||
|
||||
def execute(self, action: Action) -> str:
|
||||
"""Execute a single action and return a description of what was done."""
|
||||
match action.type:
|
||||
case "tap":
|
||||
px, py = self._to_pixels(action.x, action.y)
|
||||
self._run("shell", "input", "tap", str(px), str(py))
|
||||
desc = f"tap ({px}, {py})"
|
||||
|
||||
case "long_press":
|
||||
px, py = self._to_pixels(action.x, action.y)
|
||||
self._run("shell", "input", "swipe",
|
||||
str(px), str(py), str(px), str(py), str(action.duration))
|
||||
desc = f"long_press ({px}, {py}) {action.duration}ms"
|
||||
|
||||
case "swipe":
|
||||
px1, py1 = self._to_pixels(action.x, action.y)
|
||||
px2, py2 = self._to_pixels(action.x2, action.y2)
|
||||
self._run("shell", "input", "swipe",
|
||||
str(px1), str(py1), str(px2), str(py2), str(action.duration))
|
||||
desc = f"swipe ({px1},{py1}) → ({px2},{py2})"
|
||||
|
||||
case "type":
|
||||
# Escape special characters for ADB
|
||||
escaped = action.text.replace(" ", "%s").replace("&", "\\&")
|
||||
self._run("shell", "input", "text", escaped)
|
||||
desc = f"type '{action.text}'"
|
||||
|
||||
case "back":
|
||||
self._run("shell", "input", "keyevent", "KEYCODE_BACK")
|
||||
desc = "back"
|
||||
|
||||
case "home":
|
||||
self._run("shell", "input", "keyevent", "KEYCODE_HOME")
|
||||
desc = "home"
|
||||
|
||||
case "scroll":
|
||||
# Scroll direction: swipe center screen
|
||||
px, py = self._to_pixels(0.5, 0.5)
|
||||
if action.y < 0: # scroll up
|
||||
self._run("shell", "input", "swipe",
|
||||
str(px), str(py - 300), str(px), str(py + 300), "300")
|
||||
desc = "scroll up"
|
||||
else: # scroll down
|
||||
self._run("shell", "input", "swipe",
|
||||
str(px), str(py + 300), str(px), str(py - 300), "300")
|
||||
desc = "scroll down"
|
||||
|
||||
case "wait":
|
||||
time.sleep(action.duration / 1000)
|
||||
desc = f"wait {action.duration}ms"
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Unknown action type: {action.type}")
|
||||
|
||||
# Wait for UI to settle after action
|
||||
time.sleep(settings.action_delay)
|
||||
return desc
|
||||
3
src/grounding/__init__.py
Normal file
3
src/grounding/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .ocr_grounding import OCRGrounding
|
||||
|
||||
__all__ = ["OCRGrounding"]
|
||||
354
src/grounding/ocr_grounding.py
Normal file
354
src/grounding/ocr_grounding.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""L3 - OCR-Based UI Element Grounding
|
||||
|
||||
Locates UI elements on screen by visible text using OCR on ADB screenshots.
|
||||
Provides reliable text-to-coordinate mapping that works on Huawei/HarmonyOS
|
||||
where uiautomator dump often returns empty XML for WeChat.
|
||||
|
||||
Strategy priority (auto mode):
|
||||
1. easyocr (best Chinese recognition, deep learning based)
|
||||
2. pytesseract (fallback, fast but fragments Chinese characters)
|
||||
3. uiautomator XML dump (supplementary, often empty on Huawei WeChat)
|
||||
|
||||
All coordinates returned as normalized (0.0-1.0) for consistency with the
|
||||
existing coordinate system in adb_executor.py.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
import io
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
|
||||
from config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class TextBox:
    """A detected text region on screen.

    Attributes:
        text: Recognized text.
        x: Left edge in pixels.
        y: Top edge in pixels.
        w: Width in pixels.
        h: Height in pixels.
        confidence: Recognition confidence, 0.0-1.0.
    """
    text: str
    x: int
    y: int
    w: int
    h: int
    confidence: float

    @property
    def cx(self) -> int:
        """Horizontal center of the box in pixels."""
        half_width = self.w // 2
        return self.x + half_width

    @property
    def cy(self) -> int:
        """Vertical center of the box in pixels."""
        half_height = self.h // 2
        return self.y + half_height

    def center_normalized(self, screen_w: int, screen_h: int) -> tuple[float, float]:
        """Return the box center divided by the given screen width and height."""
        norm_x = self.cx / screen_w
        norm_y = self.cy / screen_h
        return norm_x, norm_y

    def contains_text(self, query: str, fuzzy: bool = True) -> bool:
        """Tell whether this box's text matches ``query``.

        Args:
            query: Text to search for.
            fuzzy: When True, match case-insensitively if either string is a
                substring of the other; when False, require exact equality.

        Returns:
            False whenever the query or the box text is empty.
        """
        if not (query and self.text):
            return False
        if not fuzzy:
            return self.text == query
        needle = query.lower()
        haystack = self.text.lower()
        return needle in haystack or haystack in needle

    def match_score(self, query: str) -> float:
        """Compute a match quality score (higher = better).

        Comparison is case-insensitive and the box text is stripped first.

        Scoring:
            - Text equals query: 1000 + confidence
            - Text contains query: 100 + confidence + len(query)/len(text)
            - Query contains text: 50 + confidence * len(text)/len(query)
            - Otherwise, or if either string is empty: 0
        """
        if not (query and self.text):
            return 0.0

        wanted = query.lower()
        found = self.text.lower().strip()

        if found == wanted:
            return 1000 + self.confidence
        if wanted in found:
            # Shorter surrounding text means a more precise hit.
            return 100 + self.confidence + len(wanted) / max(len(found), 1)
        if found in wanted:
            # The box only covers part of the query -- weaker match.
            return 50 + self.confidence * (len(found) / max(len(wanted), 1))
        return 0.0
|
||||
|
||||
|
||||
class OCRGrounding:
    """OCR-based element grounding for Android screens.

    Usage:
        grounding = OCRGrounding()

        # From ADB screenshot (PIL Image)
        img = capture.screenshot()
        result = grounding.find_text(img, "发送")
        if result:
            norm_x, norm_y = result.center_normalized(img.width, img.height)
            # Use norm_x, norm_y with ADBExecutor
    """

    def __init__(self, engine: str = "auto"):
        """
        Args:
            engine: OCR engine to use.
                "pytesseract" / "easyocr" / "auto" (easyocr first, pytesseract fallback)
        """
        self.engine = engine
        self._easyocr_reader = None  # lazy init (slow first load)

    # ──────────────────────────────────────────────
    # Public API
    # ──────────────────────────────────────────────

    def find_text(
        self, img: Image.Image, query: str, fuzzy: bool = True
    ) -> TextBox | None:
        """Find a UI element by visible text and return its bounding box.

        Args:
            img: PIL Image (screenshot from ADB).
            query: Text to search for (e.g. "发送", "微信", "Search").
            fuzzy: Substring/case-insensitive match.

        Returns:
            Best matching TextBox, or None if not found.
        """
        boxes = self.detect_all(img)
        matches = [b for b in boxes if b.contains_text(query, fuzzy=fuzzy)]

        if not matches:
            logger.warning("Text '%s' not found. Detected texts: %s",
                           query, [b.text for b in boxes[:20]])
            return None

        # Highest match_score wins; max() keeps the first box on ties, the
        # same element a stable descending sort would put first.
        best = max(matches, key=lambda b: b.match_score(query))
        logger.info("Found '%s' → '%s' at (%s, %s) conf=%.2f score=%.1f",
                    query, best.text, best.cx, best.cy,
                    best.confidence, best.match_score(query))
        return best

    def find_all_matches(
        self, img: Image.Image, query: str, fuzzy: bool = True
    ) -> list[TextBox]:
        """Find ALL matching elements (e.g., multiple chat contacts named similar)."""
        boxes = self.detect_all(img)
        return [b for b in boxes if b.contains_text(query, fuzzy=fuzzy)]

    def detect_all(self, img: Image.Image) -> list[TextBox]:
        """Run OCR on the full image and return all detected text boxes.

        With engine "pytesseract" or "easyocr" only that engine is used and
        its exceptions propagate.  Any other value is treated as "auto":
        easyocr first, pytesseract if easyocr raises, and an empty list if
        both fail or pytesseract finds nothing.
        """
        if self.engine == "pytesseract":
            return self._detect_pytesseract(img)
        if self.engine == "easyocr":
            return self._detect_easyocr(img)

        # auto: prefer easyocr (much better Chinese recognition)
        try:
            return self._detect_easyocr(img)
        except Exception as e:
            logger.info("easyocr failed (%s), trying pytesseract", e)

        try:
            boxes = self._detect_pytesseract(img)
            if boxes:
                return boxes
        except Exception as e:
            logger.error("All OCR engines failed: %s", e)

        return []

    def find_text_normalized(
        self, img: Image.Image, query: str, fuzzy: bool = True
    ) -> tuple[float, float] | None:
        """Convenience: find text and return normalized (x, y) center directly.

        Returns None if not found.
        """
        box = self.find_text(img, query, fuzzy=fuzzy)
        if box is None:
            return None
        return box.center_normalized(img.width, img.height)

    # ──────────────────────────────────────────────
    # pytesseract engine
    # ──────────────────────────────────────────────

    def _detect_pytesseract(self, img: Image.Image) -> list[TextBox]:
        """Detect text using pytesseract (calls tesseract binary).

        Uses chi_sim+eng for Chinese + English mixed content (common in WeChat).
        Falls back to eng-only if the first language config raises
        ``TesseractError``.

        Raises:
            RuntimeError: If tesseract fails with every language config.
        """
        import pytesseract

        # Try Chinese+English first, fall back to English only
        for lang in ["chi_sim+eng", "eng"]:
            try:
                data = pytesseract.image_to_data(
                    img,
                    lang=lang,
                    output_type=pytesseract.Output.DICT,
                    config="--psm 11"  # Sparse text: find as much text as possible
                )
                break
            except pytesseract.TesseractError:
                continue
        else:
            raise RuntimeError("Tesseract failed with all language configs")

        boxes = []
        for i, raw_text in enumerate(data["text"]):
            text = raw_text.strip()
            # Go through float() so values such as "96.5" parse too; plain
            # int("96.5") would raise ValueError.
            conf = int(float(data["conf"][i]))
            if not text or conf < 20:  # skip low-confidence noise
                continue
            boxes.append(TextBox(
                text=text,
                x=data["left"][i],
                y=data["top"][i],
                w=data["width"][i],
                h=data["height"][i],
                confidence=conf / 100.0,
            ))

        return boxes

    # ──────────────────────────────────────────────
    # easyocr engine
    # ──────────────────────────────────────────────

    def _detect_easyocr(self, img: Image.Image) -> list[TextBox]:
        """Detect text using easyocr (better for Chinese, uses deep learning).

        The reader is created on first use and reused afterwards, so the
        first call pays the model-loading cost.
        """
        import easyocr
        import numpy as np

        if self._easyocr_reader is None:
            self._easyocr_reader = easyocr.Reader(
                ["ch_sim", "en"],
                gpu=False,  # CPU is fine for single screenshots
            )

        # Convert PIL to numpy array for easyocr
        img_np = np.array(img.convert("RGB"))
        results = self._easyocr_reader.readtext(img_np)

        boxes = []
        for (bbox, text, conf) in results:
            if not text.strip():
                continue
            # bbox is [[x1,y1],[x2,y2],[x3,y3],[x4,y4]] (quadrilateral);
            # reduce it to its axis-aligned bounding rectangle.
            xs = [p[0] for p in bbox]
            ys = [p[1] for p in bbox]
            x = int(min(xs))
            y = int(min(ys))
            w = int(max(xs) - x)
            h = int(max(ys) - y)
            boxes.append(TextBox(
                text=text.strip(),
                x=x, y=y, w=w, h=h,
                confidence=float(conf),
            ))

        return boxes

    # ──────────────────────────────────────────────
    # uiautomator XML dump (supplementary, often empty on Huawei)
    # ──────────────────────────────────────────────

    def try_uiautomator_dump(self, serial: str | None = None) -> list[TextBox]:
        """Attempt to get UI elements from uiautomator dump.

        NOTE: This often returns nearly empty XML on Huawei/HarmonyOS,
        especially for WeChat. Use as a supplementary source, not primary.

        Any dump left on the device by an earlier call is deleted first, so
        a failed dump can never be answered from stale data.

        Args:
            serial: Device serial (None = use settings or first device).

        Returns:
            List of TextBox from accessibility tree, may be empty.
        """
        adb = settings.adb_path
        cmd = [adb]
        if serial or settings.device_serial:
            cmd.extend(["-s", serial or settings.device_serial])

        dump_path = "/sdcard/ui_dump.xml"
        clean_cmd = cmd + ["shell", "rm", "-f", dump_path]
        # Dump to device, then pull
        dump_cmd = cmd + ["shell", "uiautomator", "dump", dump_path]
        pull_cmd = cmd + ["shell", "cat", dump_path]

        try:
            subprocess.run(clean_cmd, capture_output=True, timeout=5)
            dump = subprocess.run(dump_cmd, capture_output=True, timeout=10)
            if dump.returncode != 0:
                logger.warning("uiautomator dump failed: exit code %s", dump.returncode)
                return []
            # Decode explicitly as UTF-8 so non-ASCII UI text does not depend
            # on the host locale.
            result = subprocess.run(
                pull_cmd, capture_output=True, text=True,
                encoding="utf-8", errors="replace", timeout=5,
            )
            xml_content = result.stdout
        except Exception as e:
            logger.warning(f"uiautomator dump failed: {e}")
            return []

        return self._parse_uiautomator_xml(xml_content)

    def _parse_uiautomator_xml(self, xml_str: str) -> list[TextBox]:
        """Parse uiautomator dump XML into TextBox list.

        Only nodes with a non-empty ``text`` attribute followed by a
        ``bounds`` attribute in the same tag are returned.  XML entities in
        the text (``&amp;``, ``&quot;``, ...) are decoded so the text can be
        compared with a plain query string.
        """
        from html import unescape

        boxes = []
        # Pattern: text="..." bounds="[x1,y1][x2,y2]"
        pattern = r'text="([^"]*)"[^>]*bounds="\[(\d+),(\d+)\]\[(\d+),(\d+)\]"'
        for match in re.finditer(pattern, xml_str):
            text = unescape(match.group(1)).strip()
            if not text:
                continue
            x1, y1, x2, y2 = (int(match.group(i)) for i in range(2, 6))
            boxes.append(TextBox(
                text=text,
                x=x1, y=y1,
                w=x2 - x1, h=y2 - y1,
                confidence=1.0,  # accessibility tree is authoritative
            ))
        return boxes

    # ──────────────────────────────────────────────
    # Hybrid: combine OCR + uiautomator
    # ──────────────────────────────────────────────

    def find_text_hybrid(
        self, img: Image.Image, query: str, fuzzy: bool = True
    ) -> TextBox | None:
        """Try uiautomator first (exact bounds), fall back to OCR.

        When uiautomator yields several matches, the one with the highest
        ``match_score`` is returned, the same ranking ``find_text`` uses.
        """
        # Try uiautomator first (precise but often empty on Huawei)
        ua_boxes = self.try_uiautomator_dump()
        ua_matches = [b for b in ua_boxes if b.contains_text(query, fuzzy=fuzzy)]
        if ua_matches:
            logger.info(f"Found '{query}' via uiautomator")
            return max(ua_matches, key=lambda b: b.match_score(query))

        # Fall back to OCR
        logger.info(f"uiautomator found nothing for '{query}', using OCR")
        return self.find_text(img, query, fuzzy=fuzzy)
|
||||
122
src/main.py
Normal file
122
src/main.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Phone GUI Agent - Main Entry Point
|
||||
|
||||
Web console for controlling the agent loop.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
|
||||
from config import settings
|
||||
from src.capture import ADBCapture
|
||||
from src.planner.agent_loop import AgentLoop
|
||||
|
||||
# ASGI application object; main() hands it to uvicorn as "src.main:app".
app = FastAPI(title="Phone GUI Agent", version="0.1.0")

# Two directory levels above this file; the web assets live under BASE_DIR/web.
BASE_DIR = Path(__file__).parent.parent
app.mount("/static", StaticFiles(directory=BASE_DIR / "web" / "static"), name="static")
templates = Jinja2Templates(directory=BASE_DIR / "web" / "templates")

# Global state
# `capture` serves the /api/device and /api/screenshot endpoints; `agent`
# runs tasks for the websocket endpoint and is stopped via /api/stop.
capture = ADBCapture()
agent = AgentLoop()
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the ``index.html`` template for the web console root page."""
    page = templates.TemplateResponse(request, "index.html")
    return page
|
||||
|
||||
|
||||
@app.get("/api/device")
async def device_info():
    """Check device connection status.

    Returns the dict produced by ``capture.check_device()``, or
    ``{"connected": False, "error": ...}`` if that call raises.
    """
    try:
        # check_device() runs blocking adb subprocesses; run it in a worker
        # thread so the event loop (and any open websocket) keeps running.
        return await asyncio.to_thread(capture.check_device)
    except Exception as e:
        return {"connected": False, "error": str(e)}
|
||||
|
||||
|
||||
@app.get("/api/screenshot")
async def take_screenshot():
    """Take a screenshot and return base64.

    Returns ``{"ok": True, "image": <base64 PNG>}`` on success, or
    ``{"ok": False, "error": ...}`` if the capture raises.
    """
    try:
        # screenshot_base64() blocks on an adb subprocess; keep it off the
        # event loop so other requests are served while it runs.
        b64 = await asyncio.to_thread(capture.screenshot_base64)
        return {"ok": True, "image": b64}
    except Exception as e:
        return {"ok": False, "error": str(e)}
|
||||
|
||||
|
||||
@app.post("/api/stop")
async def stop_task():
    """Request that the agent stop its current task and acknowledge with ok=True."""
    agent.stop()
    acknowledgement = {"ok": True}
    return acknowledgement
|
||||
|
||||
|
||||
@app.websocket("/ws/task")
async def task_websocket(ws: WebSocket):
    """WebSocket endpoint for running tasks with real-time updates.

    Client sends: {"task": "打开微信搜索张三"}
    Server streams: one "started" message, one "step" message per
    StepResult, then a final message carrying the session status.

    Step messages are scheduled from the ``on_step`` callback and are all
    awaited before the final status is sent, so the final message cannot
    overtake them and none is dropped when the handler returns.
    """
    await ws.accept()
    try:
        data = await ws.receive_json()
        task = data.get("task", "")
        if not task:
            await ws.send_json({"error": "No task provided"})
            return

        await ws.send_json({"status": "started", "task": task})

        loop = asyncio.get_running_loop()
        pending_sends = []  # futures of scheduled step messages

        def on_step(result):
            payload = {
                "status": "step",
                "step": result.step,
                "observation": result.observation,
                "thinking": result.thinking,
                "action_type": result.action_type,
                "action_desc": result.action_desc,
                # Only a 100-character preview of the base64 screenshot is sent.
                "screenshot": (
                    result.screenshot_before[:100] + "..."
                    if result.screenshot_before else None
                ),
                "error": result.error,
            }
            # Safe to call from the loop thread or from a worker thread.
            # Keeping the future also holds a reference to the send task.
            pending_sends.append(
                asyncio.run_coroutine_threadsafe(ws.send_json(payload), loop)
            )

        session = await agent.run_task(task, on_step=on_step)

        if pending_sends:
            # Flush every step message; failures of individual sends are
            # collected rather than raised here.
            await asyncio.gather(
                *(asyncio.wrap_future(f) for f in pending_sends),
                return_exceptions=True,
            )

        await ws.send_json({
            "status": session.status,
            "total_steps": len(session.steps),
            "task": task,
        })

    except WebSocketDisconnect:
        agent.stop()
    except Exception as e:
        try:
            await ws.send_json({"error": str(e)})
        except Exception:
            # The socket is already unusable; nothing more can be reported.
            pass
|
||||
|
||||
|
||||
def main():
    """Launch uvicorn for ``src.main:app`` on the configured host and port, with reload on."""
    import uvicorn

    server_options = {
        "host": settings.host,
        "port": settings.port,
        "reload": True,
    }
    uvicorn.run("src.main:app", **server_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
3
src/planner/__init__.py
Normal file
3
src/planner/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .agent_loop import AgentLoop
|
||||
|
||||
__all__ = ["AgentLoop"]
|
||||
200
src/planner/agent_loop.py
Normal file
200
src/planner/agent_loop.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""L4+L6+L7 - Agent Loop: Planning, Verification, Memory
|
||||
|
||||
The core agent loop that orchestrates the full pipeline:
|
||||
Screenshot → VLM Analysis → Action Execution → Verification → Repeat
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
||||
from src.capture import ADBCapture
|
||||
from src.vision import VLMClient
|
||||
from src.executor.adb_executor import ADBExecutor, Action
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepResult:
|
||||
step: int
|
||||
timestamp: str
|
||||
observation: str
|
||||
thinking: str
|
||||
action_type: str
|
||||
action_desc: str
|
||||
screenshot_before: str # base64
|
||||
screenshot_after: str | None = None
|
||||
verified: bool = False
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
class TaskSession:
    """State of one task run.

    Attributes:
        task: The task instruction.
        status: running / completed / failed / stopped.
        steps: Step results collected so far.
        started_at: Start timestamp string.
        finished_at: Finish timestamp string.
    """
    task: str
    status: str = "running"
    steps: list[StepResult] = field(default_factory=list)
    started_at: str = ""
    finished_at: str = ""

    def history(self) -> list[dict]:
        """Return one observation/action-type entry per step, for VLM context."""
        entries = []
        for past_step in self.steps:
            entries.append({
                "observation": past_step.observation,
                "action": {"type": past_step.action_type},
            })
        return entries
|
||||
|
||||
|
||||
class AgentLoop:
    """Main agent loop orchestrating all pipeline layers."""

    def __init__(self):
        """Create the capture, VLM and executor components for this loop."""
        self.capture = ADBCapture()
        self.vlm = VLMClient()
        self.executor = ADBExecutor(self.capture)
        self.current_session: TaskSession | None = None
        self._stop_requested = False

    def stop(self):
        """Ask the running task to stop; checked at the start of each step."""
        self._stop_requested = True

    async def run_task(self, task: str, on_step=None) -> TaskSession:
        """Execute a task through the full agent loop.

        Args:
            task: Natural language task instruction.
            on_step: Optional callback called after each step with StepResult.

        Returns:
            TaskSession with all steps and final status ("completed",
            "stopped" or "failed").
        """
        from config import settings

        session = TaskSession(
            task=task,
            started_at=datetime.now().isoformat(),
        )
        self.current_session = session
        self._stop_requested = False

        try:
            for step_num in range(1, settings.max_steps + 1):
                if self._stop_requested:
                    session.status = "stopped"
                    break

                result = await self._execute_step(step_num, task, session)
                session.steps.append(result)

                if on_step:
                    on_step(result)

                if result.action_type == "done":
                    session.status = "completed"
                    break

                if result.error:
                    # Allow up to 3 consecutive errors before failing
                    recent_errors = sum(
                        1 for s in session.steps[-3:] if s.error
                    )
                    if recent_errors >= 3:
                        session.status = "failed"
                        break
            else:
                session.status = "failed"  # max steps exceeded

        except Exception as e:
            session.status = "failed"
            if session.steps:
                session.steps[-1].error = str(e)

        session.finished_at = datetime.now().isoformat()
        self.current_session = None
        return session

    async def _execute_step(
        self, step_num: int, task: str, session: TaskSession
    ) -> StepResult:
        """Execute a single step in the agent loop.

        Captures a screenshot, asks the VLM for the next action, executes it
        and optionally takes a post-action screenshot.  Failures in any
        stage are returned as a StepResult with ``error`` set, not raised.
        """
        # `settings` is not a module-level name in this file (run_task also
        # imports it locally); without this import the verification block
        # below raised NameError after every executed action.
        from config import settings

        timestamp = datetime.now().isoformat()

        # L1: Capture screenshot (blocking adb call, kept off the event loop)
        try:
            screenshot_b64 = await asyncio.to_thread(self.capture.screenshot_base64)
        except Exception as e:
            return StepResult(
                step=step_num, timestamp=timestamp,
                observation="", thinking="",
                action_type="error", action_desc="",
                screenshot_before="", error=f"Screenshot failed: {e}"
            )

        # L2+L3+L4: VLM analysis (understanding + grounding + planning)
        try:
            response = await self.vlm.analyze_screen(
                screenshot_b64, task, session.history()
            )
        except Exception as e:
            return StepResult(
                step=step_num, timestamp=timestamp,
                observation="", thinking="",
                action_type="error", action_desc="",
                screenshot_before=screenshot_b64,
                error=f"VLM analysis failed: {e}"
            )

        # A reply without "action"/"type" is a failed step that can be
        # retried.  It must not abort the task and get its error pinned on
        # the previous step.
        try:
            observation = response.get("observation", "")
            thinking = response.get("thinking", "")
            action_data = response["action"]
            action_type = action_data["type"]
        except (AttributeError, KeyError, TypeError) as e:
            return StepResult(
                step=step_num, timestamp=timestamp,
                observation="", thinking="",
                action_type="error", action_desc="",
                screenshot_before=screenshot_b64,
                error=f"Malformed VLM response: {e!r}"
            )

        # Task complete
        if action_type == "done":
            return StepResult(
                step=step_num, timestamp=timestamp,
                observation=observation, thinking=thinking,
                action_type="done", action_desc="Task completed",
                screenshot_before=screenshot_b64,
            )

        # L5: Execute action.  execute() sleeps and shells out to adb, so it
        # also runs in a worker thread; unknown action types raise there and
        # are reported through the except branch below.
        try:
            action = Action(
                type=action_type,
                x=action_data.get("x", 0),
                y=action_data.get("y", 0),
                text=action_data.get("text", ""),
                x2=action_data.get("x2", 0),
                y2=action_data.get("y2", 0),
                duration=action_data.get("duration", 300),
            )
            action_desc = await asyncio.to_thread(self.executor.execute, action)
        except Exception as e:
            return StepResult(
                step=step_num, timestamp=timestamp,
                observation=observation, thinking=thinking,
                action_type=action_type, action_desc="",
                screenshot_before=screenshot_b64,
                error=f"Execution failed: {e}"
            )

        # L6: Verify by taking post-action screenshot
        screenshot_after = None
        if settings.verify_after_action:
            try:
                screenshot_after = await asyncio.to_thread(
                    self.capture.screenshot_base64
                )
            except Exception:
                pass  # non-critical: the step is reported as unverified

        return StepResult(
            step=step_num, timestamp=timestamp,
            observation=observation, thinking=thinking,
            action_type=action_type, action_desc=action_desc,
            screenshot_before=screenshot_b64,
            screenshot_after=screenshot_after,
            verified=screenshot_after is not None,
        )
|
||||
0
src/verifier/__init__.py
Normal file
0
src/verifier/__init__.py
Normal file
3
src/vision/__init__.py
Normal file
3
src/vision/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .vlm_client import VLMClient
|
||||
|
||||
__all__ = ["VLMClient"]
|
||||
171
src/vision/vlm_client.py
Normal file
171
src/vision/vlm_client.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""L2+L3 - Vision Language Model Client
|
||||
|
||||
Sends screenshots to VLM for screen understanding and element grounding.
|
||||
Supports multiple providers: Poe API (preferred), OpenRouter (backup), local.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import httpx
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from config import settings
|
||||
|
||||
|
||||
# System prompt (written in Chinese) sent as the first, "system"-role message
# of every VLM request. It tells the model that it is a phone GUI control
# assistant receiving an Android screenshot plus a user task, and that it must
# answer with strict JSON holding three keys: "observation", "thinking" and
# "action". The action carries a "type" (tap|swipe|type|long_press|back|home|
# scroll|wait|done) plus x/y, text, x2/y2 and duration fields. Coordinates are
# normalized to 0.0-1.0 with (0, 0) at the top-left and (1, 1) at the
# bottom-right, and action.type must be "done" once the task is complete.
# NOTE: this text is runtime data - editing it changes model behaviour.
SYSTEM_PROMPT = """你是一个手机 GUI 操控助手。你会收到一张 Android 手机截图和一个用户任务指令。
|
||||
|
||||
你的职责:
|
||||
1. 分析当前屏幕内容(识别所有 UI 元素、文本、图标、按钮)
|
||||
2. 根据任务目标,决定下一步要执行的操作
|
||||
3. 精确定位目标元素的屏幕坐标
|
||||
|
||||
输出格式(严格 JSON):
|
||||
{
|
||||
"observation": "当前屏幕的简要描述",
|
||||
"thinking": "下一步应该做什么,为什么",
|
||||
"action": {
|
||||
"type": "tap|swipe|type|long_press|back|home|scroll|wait|done",
|
||||
"x": 0.5,
|
||||
"y": 0.3,
|
||||
"text": "",
|
||||
"x2": 0.0,
|
||||
"y2": 0.0,
|
||||
"duration": 300
|
||||
}
|
||||
}
|
||||
|
||||
坐标说明:
|
||||
- x, y 为归一化坐标,范围 0.0-1.0
|
||||
- (0, 0) 是屏幕左上角,(1, 1) 是右下角
|
||||
- 点击按钮时,坐标应指向按钮的中心位置
|
||||
|
||||
当任务完成时,action.type 设为 "done"。
|
||||
"""
|
||||
|
||||
|
||||
class VLMClient:
|
||||
"""Multi-provider VLM client for screen understanding."""
|
||||
|
||||
def __init__(self):
|
||||
self.provider = settings.vlm_provider
|
||||
self.model = settings.vlm_model
|
||||
|
||||
async def analyze_screen(
|
||||
self, screenshot_b64: str, task: str, history: list[dict] | None = None
|
||||
) -> dict:
|
||||
"""Send screenshot to VLM and get structured action response.
|
||||
|
||||
Args:
|
||||
screenshot_b64: Base64-encoded PNG screenshot.
|
||||
task: User's task instruction.
|
||||
history: Previous observation/action pairs for context.
|
||||
|
||||
Returns:
|
||||
Parsed dict with observation, thinking, and action.
|
||||
"""
|
||||
messages = self._build_messages(screenshot_b64, task, history)
|
||||
|
||||
match self.provider:
|
||||
case "poe":
|
||||
raw = await self._call_poe(messages)
|
||||
case "openrouter":
|
||||
raw = await self._call_openrouter(messages)
|
||||
case "local":
|
||||
raw = await self._call_local(messages)
|
||||
case _:
|
||||
raise ValueError(f"Unknown VLM provider: {self.provider}")
|
||||
|
||||
return self._parse_response(raw)
|
||||
|
||||
def _build_messages(
|
||||
self, screenshot_b64: str, task: str, history: list[dict] | None
|
||||
) -> list[dict]:
|
||||
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
||||
|
||||
# Add history context
|
||||
if history:
|
||||
history_text = "\n".join(
|
||||
f"Step {i+1}: {h['observation']} → {h['action']['type']}"
|
||||
for i, h in enumerate(history[-5:]) # last 5 steps
|
||||
)
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": f"历史操作记录:\n{history_text}"
|
||||
})
|
||||
|
||||
# Current step: screenshot + task
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{screenshot_b64}"}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"当前任务:{task}\n\n请分析截图并给出下一步操作。"
|
||||
},
|
||||
],
|
||||
})
|
||||
return messages
|
||||
|
||||
async def _call_poe(self, messages: list[dict]) -> str:
|
||||
"""Call Poe API (preferred, cheapest)."""
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
resp = await client.post(
|
||||
"https://api.poe.com/v1/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {settings.poe_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"model": self.model, "messages": messages},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()["choices"][0]["message"]["content"]
|
||||
|
||||
async def _call_openrouter(self, messages: list[dict]) -> str:
|
||||
"""Call OpenRouter API (backup)."""
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
resp = await client.post(
|
||||
"https://openrouter.ai/api/v1/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {settings.openrouter_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={"model": self.model, "messages": messages},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()["choices"][0]["message"]["content"]
|
||||
|
||||
async def _call_local(self, messages: list[dict]) -> str:
|
||||
"""Call local vLLM/Ollama server."""
|
||||
async with httpx.AsyncClient(timeout=60) as client:
|
||||
resp = await client.post(
|
||||
"http://localhost:11434/v1/chat/completions",
|
||||
json={"model": self.model, "messages": messages},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json()["choices"][0]["message"]["content"]
|
||||
|
||||
def _parse_response(self, raw: str) -> dict:
|
||||
"""Parse VLM response into structured action dict."""
|
||||
import json
|
||||
import re
|
||||
|
||||
# Extract JSON from response (handle markdown code blocks)
|
||||
json_match = re.search(r"```(?:json)?\s*(.*?)\s*```", raw, re.DOTALL)
|
||||
if json_match:
|
||||
raw = json_match.group(1)
|
||||
|
||||
# Try to find JSON object directly
|
||||
json_match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||
if not json_match:
|
||||
raise ValueError(f"No JSON found in VLM response: {raw[:200]}")
|
||||
|
||||
parsed = json.loads(json_match.group())
|
||||
|
||||
# Validate required fields
|
||||
assert "action" in parsed, "Missing 'action' field"
|
||||
assert "type" in parsed["action"], "Missing action 'type'"
|
||||
|
||||
return parsed
|
||||
Reference in New Issue
Block a user