auto-save 2026-05-13 14:43 (~6)
This commit is contained in:
69
api/main.py
69
api/main.py
@@ -1267,10 +1267,12 @@ def delete_element(job_id: str, idx: int, element_id: str) -> Job:
|
||||
|
||||
@app.post("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout", response_model=Job)
|
||||
def cutout_element(job_id: str, idx: int, element_id: str) -> Job:
|
||||
"""提取元素 · 每次调用累积一张新图(不覆盖之前的):
|
||||
- 有 region → PIL crop(瞬时 · 保留原表情/形体)
|
||||
- 无 region → 调 nano-banana 模型生成白底图(5-15s)"""
|
||||
"""AI 提取元素 · 每次累积一张新图:
|
||||
调 nano-banana 模型生成**完整、清晰**的元素图(即使原图只露出部分也补全)。
|
||||
region 元素:先把 region + 30% padding 区域裁出作为 focus,再发给模型聚焦补全。"""
|
||||
from PIL import Image as _PILImage
|
||||
import io as _io
|
||||
import tempfile as _tempfile
|
||||
job = JOBS.get(job_id)
|
||||
if not job:
|
||||
raise HTTPException(404, "job not found")
|
||||
@@ -1288,10 +1290,12 @@ def cutout_element(job_id: str, idx: int, element_id: str) -> Job:
|
||||
|
||||
out_dir = job_dir(job_id) / "elements"
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
# 新建一个 cutout_id append 到 element.cutouts(而非覆盖)
|
||||
new_cutout_id = uuid.uuid4().hex[:8]
|
||||
out_path = out_dir / f"{idx:03d}_{element_id}_{new_cutout_id}.jpg"
|
||||
|
||||
# region 元素:先 PIL 裁出 region + 30% padding 作为 focus 给模型(让它聚焦在该元素)
|
||||
tmp_focus: Path | None = None
|
||||
model_src = src
|
||||
if el.region:
|
||||
try:
|
||||
im = _PILImage.open(src).convert("RGB")
|
||||
@@ -1301,31 +1305,46 @@ def cutout_element(job_id: str, idx: int, element_id: str) -> Job:
|
||||
y = max(0.0, min(1.0, float(r.get("y", 0))))
|
||||
w = max(0.0, min(1.0 - x, float(r.get("w", 0))))
|
||||
h = max(0.0, min(1.0 - y, float(r.get("h", 0))))
|
||||
left, top = int(x * W), int(y * H)
|
||||
right, bottom = int((x + w) * W), int((y + h) * H)
|
||||
if right - left < 4 or bottom - top < 4:
|
||||
raise HTTPException(400, "region 太小,无法提取")
|
||||
cropped = im.crop((left, top, right, bottom))
|
||||
cropped.save(out_path, format="JPEG", quality=92)
|
||||
except HTTPException:
|
||||
raise
|
||||
cx, cy = x + w / 2, y + h / 2
|
||||
# 扩大 30% 给上下文(避免裁到正好边界丢失补全 hint)
|
||||
ew, eh = w * 1.6, h * 1.6
|
||||
x0 = max(0.0, cx - ew / 2); y0 = max(0.0, cy - eh / 2)
|
||||
x1 = min(1.0, cx + ew / 2); y1 = min(1.0, cy + eh / 2)
|
||||
left, top, right, bottom = int(x0 * W), int(y0 * H), int(x1 * W), int(y1 * H)
|
||||
if right - left > 8 and bottom - top > 8:
|
||||
cropped = im.crop((left, top, right, bottom))
|
||||
tmp = _tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
|
||||
cropped.save(tmp.name, format="JPEG", quality=92)
|
||||
tmp.close()
|
||||
tmp_focus = Path(tmp.name)
|
||||
model_src = tmp_focus
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"extract failed: {e}")
|
||||
else:
|
||||
target = (el.name_en or el.name_zh).strip()
|
||||
position_hint = f" Located in the {el.position} area." if el.position else ""
|
||||
prompt = (
|
||||
f"Extract the {target} from this image as a standalone asset.{position_hint} "
|
||||
"Place it on a pure white background, isolated, no other objects."
|
||||
)
|
||||
models = [IMAGE_MODEL, "gemini-2.5-flash-image"]
|
||||
print(f"[cutout region crop failed, fallback to full frame] {e}", flush=True)
|
||||
|
||||
target = (el.name_en or el.name_zh).strip()
|
||||
prompt = (
|
||||
f"Identify the {target} in this image. "
|
||||
f"Generate a complete, high-resolution, sharply detailed image of the entire {target} as a standalone asset. "
|
||||
f"If the {target} is only partially visible in the source (cropped at edges, occluded by other objects, or out of frame), "
|
||||
"intelligently reconstruct the missing parts based on visual context so the result shows the FULL element. "
|
||||
"Place the complete element on a pure white background, isolated, with no other objects, no scene fragments, no shadows from the original scene. "
|
||||
"Preserve the element's original color palette, style, lighting character, and proportions. "
|
||||
"Output must be a clean, high-quality asset image suitable for downstream composition."
|
||||
)
|
||||
models = [IMAGE_MODEL, "gemini-2.5-flash-image"]
|
||||
img_bytes: bytes
|
||||
try:
|
||||
try:
|
||||
img_bytes, _mode = _image_edit_call(
|
||||
src, prompt, models=models, fallback_text=False, max_attempts=3,
|
||||
model_src, prompt, models=models, fallback_text=False, max_attempts=3,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(500, f"extract failed: {e}")
|
||||
out_path.write_bytes(img_bytes)
|
||||
finally:
|
||||
if tmp_focus and tmp_focus.exists():
|
||||
try: tmp_focus.unlink()
|
||||
except OSError: pass
|
||||
out_path.write_bytes(img_bytes)
|
||||
|
||||
new_frames = []
|
||||
for f in job.frames:
|
||||
@@ -1333,12 +1352,10 @@ def cutout_element(job_id: str, idx: int, element_id: str) -> Job:
|
||||
for e in f.elements:
|
||||
if e.id == element_id:
|
||||
e.cutouts = (e.cutouts or []) + [new_cutout_id]
|
||||
# 兼容:若旧字段 cutout_id 未设置,记一下让旧 UI 仍能读到一张
|
||||
if not e.cutout_id:
|
||||
e.cutout_id = new_cutout_id
|
||||
new_frames.append(f)
|
||||
msg_label = "提取(PIL)" if el.region else "提取(模型)"
|
||||
update(job, frames=new_frames, message=f"{msg_label}完成 · {el.name_zh}")
|
||||
update(job, frames=new_frames, message=f"提取完成 · {el.name_zh}")
|
||||
return job
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user