auto-save 2026-05-13 14:43 (~6)

This commit is contained in:
2026-05-13 14:44:00 +08:00
parent 9421836a6d
commit 59f6c16225
6 changed files with 106 additions and 69 deletions

View File

@@ -1267,10 +1267,12 @@ def delete_element(job_id: str, idx: int, element_id: str) -> Job:
@app.post("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout", response_model=Job)
def cutout_element(job_id: str, idx: int, element_id: str) -> Job:
"""提取元素 · 每次调用累积一张新图(不覆盖之前的)
- 有 region → PIL crop瞬时 · 保留原表情/形体)
- 无 region → 调 nano-banana 模型生成白底图5-15s"""
"""AI 提取元素 · 每次累积一张新图:
调 nano-banana 模型生成**完整、清晰**的元素图(即使原图只露出部分也补全)。
region 元素:先把 region + 30% padding 区域裁出作为 focus再发给模型聚焦补全。"""
from PIL import Image as _PILImage
import io as _io
import tempfile as _tempfile
job = JOBS.get(job_id)
if not job:
raise HTTPException(404, "job not found")
@@ -1288,10 +1290,12 @@ def cutout_element(job_id: str, idx: int, element_id: str) -> Job:
out_dir = job_dir(job_id) / "elements"
out_dir.mkdir(parents=True, exist_ok=True)
# 新建一个 cutout_id append 到 element.cutouts而非覆盖
new_cutout_id = uuid.uuid4().hex[:8]
out_path = out_dir / f"{idx:03d}_{element_id}_{new_cutout_id}.jpg"
# region 元素:先 PIL 裁出 region + 30% padding 作为 focus 给模型(让它聚焦在该元素)
tmp_focus: Path | None = None
model_src = src
if el.region:
try:
im = _PILImage.open(src).convert("RGB")
@@ -1301,31 +1305,46 @@ def cutout_element(job_id: str, idx: int, element_id: str) -> Job:
y = max(0.0, min(1.0, float(r.get("y", 0))))
w = max(0.0, min(1.0 - x, float(r.get("w", 0))))
h = max(0.0, min(1.0 - y, float(r.get("h", 0))))
left, top = int(x * W), int(y * H)
right, bottom = int((x + w) * W), int((y + h) * H)
if right - left < 4 or bottom - top < 4:
raise HTTPException(400, "region 太小,无法提取")
cropped = im.crop((left, top, right, bottom))
cropped.save(out_path, format="JPEG", quality=92)
except HTTPException:
raise
cx, cy = x + w / 2, y + h / 2
# 扩大 30% 给上下文(避免裁到正好边界丢失补全 hint
ew, eh = w * 1.6, h * 1.6
x0 = max(0.0, cx - ew / 2); y0 = max(0.0, cy - eh / 2)
x1 = min(1.0, cx + ew / 2); y1 = min(1.0, cy + eh / 2)
left, top, right, bottom = int(x0 * W), int(y0 * H), int(x1 * W), int(y1 * H)
if right - left > 8 and bottom - top > 8:
cropped = im.crop((left, top, right, bottom))
tmp = _tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
cropped.save(tmp.name, format="JPEG", quality=92)
tmp.close()
tmp_focus = Path(tmp.name)
model_src = tmp_focus
except Exception as e:
raise HTTPException(500, f"extract failed: {e}")
else:
target = (el.name_en or el.name_zh).strip()
position_hint = f" Located in the {el.position} area." if el.position else ""
prompt = (
f"Extract the {target} from this image as a standalone asset.{position_hint} "
"Place it on a pure white background, isolated, no other objects."
)
models = [IMAGE_MODEL, "gemini-2.5-flash-image"]
print(f"[cutout region crop failed, fallback to full frame] {e}", flush=True)
target = (el.name_en or el.name_zh).strip()
prompt = (
f"Identify the {target} in this image. "
f"Generate a complete, high-resolution, sharply detailed image of the entire {target} as a standalone asset. "
f"If the {target} is only partially visible in the source (cropped at edges, occluded by other objects, or out of frame), "
"intelligently reconstruct the missing parts based on visual context so the result shows the FULL element. "
"Place the complete element on a pure white background, isolated, with no other objects, no scene fragments, no shadows from the original scene. "
"Preserve the element's original color palette, style, lighting character, and proportions. "
"Output must be a clean, high-quality asset image suitable for downstream composition."
)
models = [IMAGE_MODEL, "gemini-2.5-flash-image"]
img_bytes: bytes
try:
try:
img_bytes, _mode = _image_edit_call(
src, prompt, models=models, fallback_text=False, max_attempts=3,
model_src, prompt, models=models, fallback_text=False, max_attempts=3,
)
except RuntimeError as e:
raise HTTPException(500, f"extract failed: {e}")
out_path.write_bytes(img_bytes)
finally:
if tmp_focus and tmp_focus.exists():
try: tmp_focus.unlink()
except OSError: pass
out_path.write_bytes(img_bytes)
new_frames = []
for f in job.frames:
@@ -1333,12 +1352,10 @@ def cutout_element(job_id: str, idx: int, element_id: str) -> Job:
for e in f.elements:
if e.id == element_id:
e.cutouts = (e.cutouts or []) + [new_cutout_id]
# 兼容:若旧字段 cutout_id 未设置,记一下让旧 UI 仍能读到一张
if not e.cutout_id:
e.cutout_id = new_cutout_id
new_frames.append(f)
msg_label = "提取PIL" if el.region else "提取(模型)"
update(job, frames=new_frames, message=f"{msg_label}完成 · {el.name_zh}")
update(job, frames=new_frames, message=f"提取完成 · {el.name_zh}")
return job