fix(api): harden subprocess/SSRF/concurrency and add db pool

- run(): add timeout (download 600s via DOWNLOAD_TIMEOUT_SECONDS, else 300s);
  TimeoutExpired now kills the child and fails the job instead of hanging forever
- create_job: validate_source_url() rejects file://, private/loopback/link-local
  IPs and off-allowlist hosts (SOURCE_URL_ALLOWED_HOSTS) — closes SSRF/local-read
- per-job RLock guards save_state/update/update_generated_video and the retry
  check-and-set so concurrent video workers can't clobber state.json
- db: psycopg_pool connection pool (graceful fallback if unavailable); write
  failures surfaced via logging.error instead of silent print
- read-only media GET routes use job_path() (no mkdir) to stop empty-dir spam
- wrap remaining Image.open() in with-blocks to avoid fd leaks

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-30 02:04:59 +08:00
parent 56ea8aef11
commit 3ed3f721eb
3 changed files with 191 additions and 62 deletions

View File

@@ -5,6 +5,7 @@ import base64
import hashlib
import hmac
import io
import ipaddress
import json
import os
import random
@@ -19,7 +20,7 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
from urllib.parse import urlencode
from urllib.parse import urlencode, urlparse
import httpx
from dotenv import load_dotenv
@@ -289,6 +290,21 @@ PRODUCT_ASSET_JPEG_QUALITY = max(80, min(95, int(os.getenv("PRODUCT_ASSET_JPEG_Q
VIDEO_MODEL = os.getenv("VIDEO_MODEL", "seedance").strip() or "seedance"
YTDLP_COOKIES_FILE = os.getenv("YTDLP_COOKIES_FILE", "").strip()
YTDLP_COOKIES_FROM_BROWSER = os.getenv("YTDLP_COOKIES_FROM_BROWSER", "").strip()
# Max seconds a single yt-dlp source download may run before run() kills it.
# Configured values below 60 are clamped up to 60.
DOWNLOAD_TIMEOUT_SECONDS = max(60, int(os.getenv("DOWNLOAD_TIMEOUT_SECONDS", "600")))
# SSRF guard for create_job: only fetch source videos from these host suffixes.
# Setting SOURCE_URL_ALLOWED_HOSTS (comma-separated) REPLACES this default list;
# it does not extend it. file://, private IPs and metadata endpoints are always
# rejected regardless of this list.
_DEFAULT_SOURCE_HOSTS = (
    "tiktok.com,douyin.com,ixigua.com,iesdouyin.com,"
    "youtube.com,youtu.be,instagram.com,facebook.com,fb.watch,"
    "bilibili.com,xiaohongshu.com,xhslink.com,kuaishou.com,v.kuaishou.com"
)


def _parse_allowed_hosts(raw: str) -> set[str]:
    """Turn a comma-separated host list into a set of bare, lower-case suffixes.

    Entries are trimmed and lower-cased. A leading "*." wildcard and any
    leading/trailing dots are removed, because the allowlist is matched as
    ``host == suffix or host.endswith("." + suffix)`` and an entry spelled
    ".example.com" or "*.example.com" would otherwise never match anything.
    Blank entries are dropped.
    """
    hosts: set[str] = set()
    for entry in raw.split(","):
        suffix = entry.strip().lower().removeprefix("*.").strip(".")
        if suffix:
            hosts.add(suffix)
    return hosts


# An env var that is set but blank (or contains only separators) falls back to
# the defaults: an empty allowlist would reject every source URL.
SOURCE_URL_ALLOWED_HOSTS = (
    _parse_allowed_hosts(os.getenv("SOURCE_URL_ALLOWED_HOSTS", ""))
    or _parse_allowed_hosts(_DEFAULT_SOURCE_HOSTS)
)
AUDIO_PRODUCT_BRIEF = os.getenv(
"AUDIO_PRODUCT_BRIEF",
"SKG smart massage products for everyday neck-and-shoulder, back, eye, knee, or foot relaxation. Ads should feel premium, clean, trustworthy, and must not make medical efficacy claims.",
@@ -1373,6 +1389,12 @@ def job_dir(job_id: str) -> Path:
return d
def job_path(job_id: str) -> Path:
    """Return the job's directory under JOBS_DIR without touching the filesystem.

    Intended for read-only routes: because nothing is created here, requests
    for unknown job ids cannot litter JOBS_DIR with empty directories.
    """
    return JOBS_DIR.joinpath(job_id)
def source_audio_url_for(job_id: str) -> str:
    """Return the URL of the job's audio.wav, or "" when that file is absent."""
    wav_file = JOBS_DIR / job_id / "audio.wav"
    if not wav_file.exists():
        return ""
    return f"/jobs/{job_id}/audio.wav"
@@ -1384,16 +1406,34 @@ def job_with_artifacts(job: Job) -> Job:
return job.model_copy(update=updates)
# Guards the registry below; held only while looking up / inserting a lock.
_JOB_LOCKS_GUARD = threading.Lock()
# job id -> that job's re-entrant lock. Entries are created on first use and
# are never removed by this block.
_JOB_LOCKS: dict[str, threading.RLock] = {}


def _job_lock(job_id: str) -> threading.RLock:
    """Return the re-entrant lock for *job_id*, creating it on first request.

    Callers serialize read-modify-write sequences on a shared Job object and
    its state.json by holding this lock. It is an RLock so a holder can call
    into another function that takes the same job's lock (e.g. update() ->
    save_state()) without deadlocking.
    """
    with _JOB_LOCKS_GUARD:
        try:
            return _JOB_LOCKS[job_id]
        except KeyError:
            fresh = _JOB_LOCKS[job_id] = threading.RLock()
            return fresh
def save_state(job: Job) -> None:
    """Write the job to <job dir>/state.json and pass it to db.index_job.

    The per-job lock is held for the whole operation so concurrent writers for
    the same job are serialized. The JSON is written to a sibling temp file and
    then moved into place with os.replace(), so state.json is swapped in one
    step instead of being truncated and rewritten in place: a crash mid-write,
    or a reader that does not hold the lock, never sees a partial file.
    """
    with _job_lock(job.id):
        state_path = job_dir(job.id) / "state.json"
        # same directory as the target so the replace stays on one filesystem
        tmp_path = state_path.with_name(state_path.name + ".tmp")
        tmp_path.write_text(job.model_dump_json(indent=2))
        os.replace(tmp_path, state_path)
        db.index_job(job.model_dump(), str(state_path))
def update(job: Job, **kw) -> None:
    """Set each keyword argument as an attribute on *job*, then persist it.

    The assignments and the save_state() call run under the job's lock, so the
    whole mutate-and-save sequence is one critical section per job.
    """
    with _job_lock(job.id):
        for field_name in kw:
            setattr(job, field_name, kw[field_name])
        save_state(job)
def public_api_base() -> str:
@@ -1877,7 +1917,8 @@ def storyboard_ref_url(job_id: str, ref: dict | None) -> str:
def prepare_video_reference(src: Path, dst: Path, size: tuple[int, int] = (720, 1280)) -> None:
dst.parent.mkdir(parents=True, exist_ok=True)
img = Image.open(src).convert("RGB")
with Image.open(src) as _raw:
img = _raw.convert("RGB")
img.thumbnail(size, Image.Resampling.LANCZOS)
canvas = Image.new("RGB", size, (8, 8, 10))
x = (size[0] - img.width) // 2
@@ -1890,15 +1931,18 @@ def update_generated_video(job_id: str, video_id: str, **kw) -> None:
job = JOBS.get(job_id)
if not job:
return
updated = []
for v in job.generated_videos:
if v.id == video_id:
data = v.model_dump()
data.update(kw)
updated.append(GeneratedVideo(**data))
else:
updated.append(v)
update(job, generated_videos=updated)
# hold the job lock across the whole read-modify-write so concurrent video
# workers updating different videos on the same job don't clobber the list.
with _job_lock(job_id):
updated = []
for v in job.generated_videos:
if v.id == video_id:
data = v.model_dump()
data.update(kw)
updated.append(GeneratedVideo(**data))
else:
updated.append(v)
update(job, generated_videos=updated)
def generated_video_exists(job_id: str, video_id: str) -> bool:
@@ -2840,9 +2884,15 @@ def _normalize_media_cmd(cmd: list[str]) -> list[str]:
return cmd
def run(cmd: list[str], cwd: Path | None = None) -> str:
def run(cmd: list[str], cwd: Path | None = None, timeout: float | None = 300) -> str:
cmd = _normalize_media_cmd(cmd)
res = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
try:
# timeout guards against a stalled yt-dlp/ffmpeg/ffprobe hanging the worker
# forever (job stuck in downloading/splitting + leaked subprocess). On
# timeout subprocess.run kills the child before re-raising.
res = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, timeout=timeout)
except subprocess.TimeoutExpired as e:
raise RuntimeError(f"cmd timed out after {timeout}s: {' '.join(cmd[:3])}...") from e
if res.returncode != 0:
# ffmpeg 把 banner 写 stderr挑最后几行真错误一般在末尾
tail = "\n".join(res.stderr.splitlines()[-12:]) or res.stderr[-800:]
@@ -2850,6 +2900,36 @@ def run(cmd: list[str], cwd: Path | None = None) -> str:
return res.stdout
def validate_source_url(url: str) -> str:
    """Reject SSRF / local-file vectors before a URL reaches yt-dlp.

    yt-dlp will happily resolve file://, http://169.254.169.254/ (cloud metadata),
    loopback and internal hosts, so a raw job URL is an SSRF / local-read primitive.
    We require http(s), block private/loopback/link-local IP literals, and restrict
    to a configurable allowlist of known short-video platforms.

    Args:
        url: The user-supplied source URL.

    Returns:
        The whitespace-stripped URL. ``upload://`` URLs are returned as-is
        without any further checks.

    Raises:
        HTTPException: 400 when the URL is empty, malformed, not http(s), has
            no host, points at a non-public IP literal, or is off the allowlist.
    """
    raw = (url or "").strip()
    if not raw:
        raise HTTPException(400, "url required")
    if raw.startswith("upload://"):
        return raw
    # urlparse silently drops embedded tab/newline characters, so the host it
    # reports could differ from what a downstream tool sees in the raw string
    # we return. Refuse control characters, whitespace and backslashes outright.
    if any(ch <= " " or ch in "\x7f\\" for ch in raw):
        raise HTTPException(400, "链接格式无效")
    try:
        # urlparse raises ValueError on malformed bracketed (IPv6-style) hosts;
        # without this guard that would surface as a 500 instead of a 400.
        parsed = urlparse(raw)
        host = (parsed.hostname or "").lower()
    except ValueError as exc:
        raise HTTPException(400, "链接格式无效") from exc
    if parsed.scheme not in {"http", "https"}:
        raise HTTPException(400, "仅支持 http/https 视频链接")
    if not host:
        raise HTTPException(400, "链接缺少主机名")
    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        ip = None  # hostname, not a bare IP literal
    if ip is not None and (
        ip.is_private
        or ip.is_loopback
        or ip.is_link_local
        or ip.is_reserved
        or ip.is_multicast
        or ip.is_unspecified
    ):
        raise HTTPException(400, "不允许访问内网 / 本地地址")
    # suffix match: the host itself or any subdomain of an allowlisted entry
    if not any(host == suffix or host.endswith("." + suffix) for suffix in SOURCE_URL_ALLOWED_HOSTS):
        raise HTTPException(400, "链接域名不在允许列表(仅支持主流短视频平台)")
    return raw
def ytdlp_cookie_args() -> list[str]:
if YTDLP_COOKIES_FILE:
cookies = Path(YTDLP_COOKIES_FILE).expanduser()
@@ -3101,7 +3181,8 @@ def _focus_source_for_element(job_id: str, idx: int, el: KeyElement) -> tuple[Pa
if not el.region:
return model_src, tmp_focus
try:
im = Image.open(src).convert("RGB")
with Image.open(src) as _raw:
im = _raw.convert("RGB")
W, H = im.size
r = el.region
x = max(0.0, min(1.0, float(r.get("x", 0))))
@@ -3144,7 +3225,8 @@ def _make_reference_contact_sheet(job_id: str, frame_indices: list[int], out_pat
thumbs: list[Image.Image] = []
for p in paths:
try:
im = Image.open(p).convert("RGB")
with Image.open(p) as _raw:
im = _raw.convert("RGB")
im.thumbnail((420, 420), Image.Resampling.LANCZOS)
canvas = Image.new("RGB", (420, 420), (245, 245, 245))
canvas.paste(im, ((420 - im.width) // 2, (420 - im.height) // 2))
@@ -3182,7 +3264,8 @@ def _make_paths_contact_sheet(paths: list[Path], out_path: Path, max_items: int
thumbs: list[Image.Image] = []
for p in usable:
try:
im = Image.open(p).convert("RGB")
with Image.open(p) as _raw:
im = _raw.convert("RGB")
im.thumbnail((420, 420), Image.Resampling.LANCZOS)
canvas = Image.new("RGB", (420, 420), (245, 245, 245))
canvas.paste(im, ((420 - im.width) // 2, (420 - im.height) // 2))
@@ -3591,7 +3674,7 @@ def pipeline_download(job_id: str) -> None:
*ytdlp_cookie_args(),
job.url,
]
run(cmd)
run(cmd, timeout=DOWNLOAD_TIMEOUT_SECONDS)
if not mp4.exists():
raise RuntimeError("下载完成但找不到 source.mp4")
@@ -5182,12 +5265,12 @@ def _prepare_image_edit_bytes(image_path: Path, max_side: int) -> bytes:
import io as _io
from PIL import Image as _PILImage
try:
im = _PILImage.open(image_path)
if max(im.size) > max_side:
im.thumbnail((max_side, max_side), _PILImage.LANCZOS)
buf = _io.BytesIO()
im.convert("RGB").save(buf, format="JPEG", quality=88)
return buf.getvalue()
with _PILImage.open(image_path) as im:
if max(im.size) > max_side:
im.thumbnail((max_side, max_side), _PILImage.LANCZOS)
buf = _io.BytesIO()
im.convert("RGB").save(buf, format="JPEG", quality=88)
return buf.getvalue()
except Exception:
return image_path.read_bytes()
@@ -6564,11 +6647,10 @@ def list_jobs(request: Request, limit: int | None = None) -> list[JobSummary]:
@app.post("/jobs", response_model=Job)
async def create_job(req: CreateJobReq, bg: BackgroundTasks, request: Request) -> Job:
if not req.url.strip():
raise HTTPException(400, "url required")
safe_url = validate_source_url(req.url)
user = data_user_from_request(request)
job_id = uuid.uuid4().hex[:12]
job = Job(id=job_id, url=req.url.strip())
job = Job(id=job_id, url=safe_url)
assign_owner(job, user)
JOBS[job_id] = job
save_state(job)
@@ -6585,20 +6667,23 @@ async def retry_job_download(job_id: str, bg: BackgroundTasks) -> Job:
source_kind = getattr(job, "source_kind", "")
if source_kind == "upload" or job.url.startswith("upload://"):
raise HTTPException(409, "uploaded videos cannot be redownloaded; upload the file again")
if job.status in {"downloading", "splitting", "transcribing"}:
raise HTTPException(409, f"job is busy: {job.status}")
mp4 = job_dir(job_id) / "source.mp4"
if mp4.exists() and mp4.stat().st_size == 0:
mp4.unlink()
update(
job,
status="downloading",
progress=1,
error="",
message="重新提交下载…",
video_url="",
)
# atomic check-and-set: two concurrent retries (or a retry racing the live
# pipeline) must not both launch yt-dlp against the same source.mp4.
with _job_lock(job_id):
if job.status in {"downloading", "splitting", "transcribing"}:
raise HTTPException(409, f"job is busy: {job.status}")
mp4 = job_dir(job_id) / "source.mp4"
if mp4.exists() and mp4.stat().st_size == 0:
mp4.unlink()
update(
job,
status="downloading",
progress=1,
error="",
message="重新提交下载…",
video_url="",
)
bg.add_task(pipeline_download, job_id)
return job
@@ -6933,7 +7018,7 @@ async def trigger_transcribe(job_id: str, bg: BackgroundTasks) -> Job:
@app.get("/jobs/{job_id}/video.mp4")
def get_video(job_id: str):
    """Serve the job's source.mp4, or respond 404 when the file does not exist."""
    source_mp4 = job_path(job_id) / "source.mp4"
    if source_mp4.exists():
        return FileResponse(source_mp4, media_type="video/mp4")
    raise HTTPException(404, "video not found")
@@ -6941,7 +7026,7 @@ def get_video(job_id: str):
@app.get("/jobs/{job_id}/audio.wav")
def get_source_audio(job_id: str):
    """Serve the job's audio.wav, or respond 404 when the file does not exist."""
    wav_file = job_path(job_id) / "audio.wav"
    if wav_file.exists():
        return FileResponse(wav_file, media_type="audio/wav")
    raise HTTPException(404, "audio not found")
@@ -6949,7 +7034,7 @@ def get_source_audio(job_id: str):
@app.get("/jobs/{job_id}/audio-script.mp3")
def get_audio_script(job_id: str):
    """Serve the job's audio_script.mp3, or respond 404 when the file does not exist."""
    script_mp3 = job_path(job_id) / "audio_script.mp3"
    if script_mp3.exists():
        return FileResponse(script_mp3, media_type="audio/mpeg")
    raise HTTPException(404, "audio script not found")
@@ -6957,7 +7042,7 @@ def get_audio_script(job_id: str):
@app.get("/jobs/{job_id}/frames/{idx}.jpg")
def get_frame(job_id: str, idx: int):
    """Serve frames/<idx zero-padded to 3 digits>.jpg for the job, or 404 if absent."""
    frame_file = job_path(job_id) / "frames" / f"{idx:03d}.jpg"
    if frame_file.exists():
        return FileResponse(frame_file, media_type="image/jpeg")
    raise HTTPException(404, "frame not found")
@@ -7060,7 +7145,7 @@ def generate_image(job_id: str, idx: int, req: GenerateReq) -> Job:
@app.get("/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg")
def get_generated_image(job_id: str, idx: int, gen_id: str):
    """Serve gen/<idx:03d>_<gen_id>.jpg for the job, or 404 if the file is absent."""
    generated_file = job_path(job_id) / "gen" / f"{idx:03d}_{gen_id}.jpg"
    if generated_file.exists():
        return FileResponse(generated_file, media_type="image/jpeg")
    raise HTTPException(404, "generated image not found")
@@ -7263,7 +7348,7 @@ def cleanup_frame(job_id: str, idx: int, req: CleanupReq | None = None) -> Job:
@app.get("/jobs/{job_id}/frames/{idx}/cleaned.jpg")
def get_cleaned_frame(job_id: str, idx: int):
    """Serve cleaned/<idx:03d>.jpg for the job, or 404 if the file is absent."""
    cleaned_file = job_path(job_id) / "cleaned" / f"{idx:03d}.jpg"
    if cleaned_file.exists():
        return FileResponse(cleaned_file, media_type="image/jpeg")
    raise HTTPException(404, "cleaned frame not found")
@@ -7793,7 +7878,8 @@ def cutout_element(job_id: str, idx: int, element_id: str) -> Job:
model_src = src
if el.region:
try:
im = _PILImage.open(src).convert("RGB")
with _PILImage.open(src) as _raw:
im = _raw.convert("RGB")
W, H = im.size
r = el.region
x = max(0.0, min(1.0, float(r.get("x", 0))))
@@ -9352,7 +9438,7 @@ def batch_generate_all_storyboard(job_id: str, req: BatchGenerateStoryboardReq)
@app.get("/jobs/{job_id}/storyboard-videos/{video_id}.mp4")
def get_storyboard_video(job_id: str, video_id: str):
    """Serve storyboard_videos/<video_id>/video.mp4 for the job, or 404 if absent."""
    clip_file = job_path(job_id) / "storyboard_videos" / video_id / "video.mp4"
    if clip_file.exists():
        return FileResponse(clip_file, media_type="video/mp4")
    raise HTTPException(404, "storyboard video not found")
@@ -10115,7 +10201,8 @@ def copy_character_library_assets(job_id: str, req: CopyCharacterLibraryAssetReq
asset_id = uuid.uuid4().hex[:12]
out = out_dir / f"{asset_id}.jpg"
try:
img = Image.open(src).convert("RGB")
with Image.open(src) as _raw:
img = _raw.convert("RGB")
img.thumbnail((1600, 1600), Image.Resampling.LANCZOS)
img.save(out, "JPEG", quality=94)
except Exception as e:
@@ -10544,7 +10631,7 @@ def get_agent_run(run_id: str) -> AgentRun:
@app.get("/agent-runs/{run_id}/final.mp4")
def get_agent_run_final(run_id: str):
    """Serve final/agent-<run id>.mp4 from the run's job directory, or 404 if absent.

    The run is looked up via get_agent_run(); whatever that call raises for an
    unknown run id propagates unchanged.
    """
    # named agent_run so the local does not shadow the module-level run() helper
    agent_run = get_agent_run(run_id)
    final_file = job_path(agent_run.job_id) / "final" / f"agent-{agent_run.id}.mp4"
    if final_file.exists():
        return FileResponse(final_file, media_type="video/mp4")
    raise HTTPException(404, "final video not found")
@@ -10588,9 +10675,11 @@ def create_product_fusion_guide(job_id: str, req: ProductFusionShot) -> dict:
h = max(0.02, min(1.0 - y, float(region.h)))
try:
base = Image.open(person_path).convert("RGB")
with Image.open(person_path) as _raw:
base = _raw.convert("RGB")
base.thumbnail((1600, 1600), Image.Resampling.LANCZOS)
product = product_image_alpha(Image.open(product_path))
with Image.open(product_path) as _praw:
product = product_image_alpha(_praw)
bw, bh = base.size
box = (
int(round(x * bw)),
@@ -10693,7 +10782,7 @@ def generate_product_fusion_descriptions(job_id: str, req: ProductFusionDescript
@app.get("/jobs/{job_id}/assets/{asset_id}.jpg")
def get_storyboard_asset(job_id: str, asset_id: str):
    """Serve assets/<asset_id>.jpg for the job, or 404 if the file is absent."""
    asset_file = job_path(job_id) / "assets" / f"{asset_id}.jpg"
    if asset_file.exists():
        return FileResponse(asset_file, media_type="image/jpeg")
    raise HTTPException(404, "asset not found")
@@ -10826,7 +10915,7 @@ def remove_storyboard_image(job_id: str, ref_id: str) -> Job:
@app.get("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutouts/{cutout_id}.jpg")
def get_cutout_versioned(job_id: str, idx: int, element_id: str, cutout_id: str):
    """Serve elements/<idx:03d>_<element_id>_<cutout_id>.jpg for the job, or 404 if absent."""
    cutout_file = job_path(job_id) / "elements" / f"{idx:03d}_{element_id}_{cutout_id}.jpg"
    if cutout_file.exists():
        return FileResponse(cutout_file, media_type="image/jpeg")
    raise HTTPException(404, "cutout not found")
@@ -10835,9 +10924,9 @@ def get_cutout_versioned(job_id: str, idx: int, element_id: str, cutout_id: str)
@app.get("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout.jpg")
def get_cutout(job_id: str, idx: int, element_id: str):
"""旧路径兼容v1 单图)→ 找 elements/{idx}_{element_id}.jpg 或 .png"""
p = job_dir(job_id) / "elements" / f"{idx:03d}_{element_id}.jpg"
p = job_path(job_id) / "elements" / f"{idx:03d}_{element_id}.jpg"
if not p.exists():
legacy = job_dir(job_id) / "elements" / f"{idx:03d}_{element_id}.png"
legacy = job_path(job_id) / "elements" / f"{idx:03d}_{element_id}.png"
if legacy.exists():
return FileResponse(legacy, media_type="image/jpeg")
raise HTTPException(404, "cutout not found")