auto-save 2026-05-13 00:22 (~4)
This commit is contained in:
140
api/main.py
140
api/main.py
@@ -28,6 +28,7 @@ ASR_MODEL = os.getenv("ASR_MODEL", "whisper-1")
|
||||
# Default model identifiers. Each one can be overridden through the
# environment variable that carries the same name as the constant.
TRANSLATE_MODEL = os.environ.get("TRANSLATE_MODEL", "gemini-2.5-flash")
REWRITE_MODEL = os.environ.get("REWRITE_MODEL", "gemini-2.5-pro")
VISION_MODEL = os.environ.get("VISION_MODEL", "gemini-2.5-flash")
# Used by the image generation endpoint when a request leaves `model` empty.
IMAGE_MODEL = os.environ.get("IMAGE_MODEL", "gemini-3-pro-image-preview")
|
||||
|
||||
# OpenAI 客户端(OpenAI 兼容网关,含 SKG ezlink)
|
||||
from openai import OpenAI
|
||||
@@ -52,11 +53,22 @@ JobStatus = Literal[
|
||||
# Read from the KEYFRAME_COUNT environment variable (default "5") and parsed
# with int(), so a non-integer value raises ValueError at import time.
# NOTE(review): presumably the number of key frames extracted per job — confirm.
KEYFRAME_COUNT = int(os.environ.get("KEYFRAME_COUNT", "5"))
|
||||
|
||||
|
||||
# One image produced for a key frame by the image generation endpoint.
# (Documented with comments rather than a docstring so the generated schema
# description stays exactly as it was.)
class GeneratedImage(BaseModel):
    # Short identifier: 12 hex characters taken from a uuid.
    id: str
    # Prompt text that was sent to the image model.
    prompt: str
    # Name of the model that produced the image.
    model: str
    # "edit" (generated with a reference image) | "text" (text only).
    mode: str = "edit"
    # Download path: /jobs/{job_id}/frames/{idx}/gen/{id}.jpg
    url: str
    # Whether this image is the one currently chosen for its frame.
    selected: bool = False
    # Creation time as a float timestamp; 0.0 when not provided.
    created_at: float = 0.0
|
||||
|
||||
|
||||
# A key frame of a job together with everything derived from it.
class KeyFrame(BaseModel):
    # Frame index; the endpoints in this file look frames up by this value.
    index: int
    timestamp: float
    url: str
    # Result of the vision model: {scene, objects, style, suggested_prompt};
    # None until the frame has been described.
    description: dict | None = None
    # Images generated for this frame.
    # NOTE(review): relies on pydantic copying the mutable default for every
    # instance — confirm for the pinned pydantic version.
    generated_images: list[GeneratedImage] = []
|
||||
|
||||
|
||||
class TranscriptSegment(BaseModel):
|
||||
@@ -583,6 +595,134 @@ def get_frame(job_id: str, idx: int):
|
||||
return FileResponse(p, media_type="image/jpeg")
|
||||
|
||||
|
||||
# Request body of POST /jobs/{job_id}/frames/{idx}/generate.
class GenerateReq(BaseModel):
    # Main prompt text.
    prompt: str
    # Optional text appended to the main prompt.
    extra_prompt: str = ""
    # Model override; leave empty to use the IMAGE_MODEL default.
    model: str = ""
    # "edit" sends the key frame as a reference image, "text" is text only.
    mode: str = "edit"
|
||||
|
||||
|
||||
@app.post("/jobs/{job_id}/frames/{idx}/generate", response_model=Job)
def generate_image(job_id: str, idx: int, req: GenerateReq) -> Job:
    """Generate a new image for a key frame and attach it to the job.

    With ``req.mode == "edit"`` the key frame JPEG is sent to the backend as a
    base64 data URI next to the prompt (image-to-image). Any other mode value
    performs a prompt-only generation through the OpenAI-compatible client
    (text-to-image).

    Args:
        job_id: Identifier of the job in ``JOBS``.
        idx: ``index`` of the key frame the image is generated for.
        req: Prompt, optional extra prompt, optional model override and mode.

    Returns:
        The job, with a new ``GeneratedImage`` appended to the matching frame.

    Raises:
        HTTPException: 404 when the job, the frame or the frame file is
            missing; 400 when no prompt text is left after stripping; 500 when
            the backend call fails or its response carries no usable image.
    """
    # Function-scope imports live at the top of the function. `import httpx`
    # used to sit inside the "edit" branch, which left the local name unbound
    # in text mode, so evaluating `except httpx.HTTPStatusError` raised
    # UnboundLocalError instead of reporting the real failure.
    import base64 as b64lib
    import binascii
    import time as _time

    import httpx

    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")
    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")
    frame_path = job_dir(job_id) / "frames" / f"{idx:03d}.jpg"
    if not frame_path.exists():
        raise HTTPException(404, "frame file missing")

    # Join only the non-empty parts so an empty main prompt does not produce a
    # prompt that starts with ". ".
    prompt_parts = [
        part for part in (req.prompt.strip(), req.extra_prompt.strip()) if part
    ]
    full_prompt = ". ".join(prompt_parts)
    if not full_prompt:
        raise HTTPException(400, "prompt required")

    model = req.model or IMAGE_MODEL
    gen_id = uuid.uuid4().hex[:12]

    try:
        if req.mode == "edit":
            # image-to-image: call the generations endpoint with an extra
            # `image` field. The OpenAI SDK does not expose that parameter
            # directly, hence the raw httpx request.
            img_b64 = b64lib.b64encode(frame_path.read_bytes()).decode("ascii")
            data_uri = f"data:image/jpeg;base64,{img_b64}"
            with httpx.Client(timeout=120) as client:
                r = client.post(
                    f"{LLM_BASE_URL}/images/generations",
                    headers={
                        "Authorization": f"Bearer {LLM_API_KEY}",
                        "Content-Type": "application/json",
                    },
                    json={
                        "model": model,
                        "prompt": full_prompt,
                        "image": data_uri,
                        "n": 1,
                    },
                )
                r.raise_for_status()
                resp_data = r.json()
        else:
            # text-only
            resp = llm().images.generate(model=model, prompt=full_prompt, n=1)
            resp_data = (
                resp.model_dump()
                if hasattr(resp, "model_dump")
                else {"data": [{"b64_json": resp.data[0].b64_json}]}
            )
    except httpx.HTTPStatusError as e:
        raise HTTPException(
            500,
            f"image gen HTTP {e.response.status_code}: {e.response.text[:300]}",
        ) from e
    except Exception as e:
        # Endpoint boundary: any other backend failure is reported as a 500.
        raise HTTPException(500, f"image gen failed: {e}") from e

    # Validate the payload shape before indexing into it, so an unexpected
    # response is reported explicitly instead of raising AttributeError.
    data_arr = resp_data.get("data") if isinstance(resp_data, dict) else None
    if not isinstance(data_arr, list) or not data_arr:
        raise HTTPException(500, "image gen returned no data")

    item = data_arr[0]
    b64 = item.get("b64_json") if isinstance(item, dict) else None
    if not b64:
        raise HTTPException(500, "image gen returned no b64_json")

    try:
        image_bytes = b64lib.b64decode(b64)
    except (binascii.Error, TypeError, ValueError) as e:
        raise HTTPException(500, f"image gen returned invalid base64: {e}") from e

    # Store locally as jobs/<id>/gen/<idx>_<gen_id>.jpg.
    # NOTE(review): the bytes are written as returned; their actual image
    # format is not checked against the .jpg name — confirm with the backend.
    gen_dir = job_dir(job_id) / "gen"
    gen_dir.mkdir(parents=True, exist_ok=True)
    out_path = gen_dir / f"{idx:03d}_{gen_id}.jpg"
    out_path.write_bytes(image_bytes)

    new_gen = GeneratedImage(
        id=gen_id,
        prompt=full_prompt,
        model=model,
        mode=req.mode,
        url=f"/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg",
        selected=False,
        created_at=_time.time(),
    )

    # Write the result back into job.frames. The list is re-read here rather
    # than reusing `frame`, because the backend call above can take a while.
    for f in job.frames:
        if f.index == idx:
            f.generated_images = f.generated_images + [new_gen]
    update(job, frames=job.frames, message=f"生图完成 · 分镜 {idx + 1}")
    return job
|
||||
|
||||
|
||||
@app.get("/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg")
def get_generated_image(job_id: str, idx: int, gen_id: str):
    """Serve a stored generated image of a key frame.

    The file is looked up at ``<job dir>/gen/<idx:03d>_<gen_id>.jpg`` and
    returned with the ``image/jpeg`` media type.

    Raises:
        HTTPException: 404 when no file exists at that path.
    """
    image_path = job_dir(job_id) / "gen" / f"{idx:03d}_{gen_id}.jpg"
    if image_path.exists():
        return FileResponse(image_path, media_type="image/jpeg")
    raise HTTPException(404, "generated image not found")
|
||||
|
||||
|
||||
# Request body of POST /jobs/{job_id}/frames/{idx}/gen/{gen_id}/select.
class SelectGenReq(BaseModel):
    # Value to store in the targeted image's `selected` flag.
    selected: bool
|
||||
|
||||
|
||||
@app.post("/jobs/{job_id}/frames/{idx}/gen/{gen_id}/select", response_model=Job)
def select_generated(job_id: str, idx: int, gen_id: str, req: SelectGenReq) -> Job:
    """Select or deselect one generated image of a key frame.

    Selection is exclusive per frame: the targeted image takes
    ``req.selected`` and every other generated image of that frame is
    deselected.

    Args:
        job_id: Identifier of the job in ``JOBS``.
        idx: ``index`` of the key frame that owns the image.
        gen_id: ``id`` of the generated image to update.
        req: Carries the new ``selected`` value.

    Returns:
        The updated job.

    Raises:
        HTTPException: 404 when the job, the frame or the generated image
            does not exist.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    frame = next((f for f in job.frames if f.index == idx), None)
    if not frame:
        raise HTTPException(404, "frame not found")

    # Check the target before touching any flag: an unknown gen_id used to
    # fall through the loop and silently clear the frame's current selection.
    if not any(g.id == gen_id for g in frame.generated_images):
        raise HTTPException(404, "generated image not found")

    # Single choice: at most one image per frame may be selected.
    for g in frame.generated_images:
        g.selected = req.selected if g.id == gen_id else False

    update(job, frames=job.frames)
    return job
|
||||
|
||||
|
||||
@app.post("/jobs/{job_id}/frames/{idx}/describe", response_model=Job)
|
||||
def describe_frame(job_id: str, idx: int) -> Job:
|
||||
"""调 vision 模型识别该关键帧,返回结构化描述。"""
|
||||
|
||||
Reference in New Issue
Block a user