From 3ed3f721eba6406d727d4bd9868754602e8ed702 Mon Sep 17 00:00:00 2001 From: kang Date: Sat, 30 May 2026 02:04:59 +0800 Subject: [PATCH] fix(api): harden subprocess/SSRF/concurrency and add db pool MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - run(): add timeout (download 600s via DOWNLOAD_TIMEOUT_SECONDS, else 300s); TimeoutExpired now kills the child and fails the job instead of hanging forever - create_job: validate_source_url() rejects file://, private/loopback/link-local IPs and off-allowlist hosts (SOURCE_URL_ALLOWED_HOSTS) — closes SSRF/local-read - per-job RLock guards save_state/update/update_generated_video and the retry check-and-set so concurrent video workers can't clobber state.json - db: psycopg_pool connection pool (graceful fallback if unavailable); write failures surfaced via logging.error instead of silent print - read-only media GET routes use job_path() (no mkdir) to stop empty-dir spam - wrap remaining Image.open() in with-blocks to avoid fd leaks Co-Authored-By: Claude Opus 4.8 (1M context) --- api/db.py | 41 ++++++++- api/main.py | 211 ++++++++++++++++++++++++++++++------------- api/requirements.txt | 1 + 3 files changed, 191 insertions(+), 62 deletions(-) diff --git a/api/db.py b/api/db.py index 04ce407..d599ebf 100644 --- a/api/db.py +++ b/api/db.py @@ -1,6 +1,8 @@ from __future__ import annotations +import logging import os +import threading import time import uuid from datetime import datetime, timezone @@ -15,18 +17,53 @@ except ModuleNotFoundError: # Local dev can still run without Postgres deps ins dict_row = None Jsonb = None +try: + from psycopg_pool import ConnectionPool +except ModuleNotFoundError: # Pool is optional; fall back to per-call connections. + ConnectionPool = None + + +logger = logging.getLogger("skg.db") DATABASE_URL = os.getenv("DATABASE_URL", "").strip() DB_ENABLED = bool(DATABASE_URL and psycopg is not None) +_POOL = None +_POOL_LOCK = threading.Lock() + def enabled() -> bool: return DB_ENABLED +def _pool(): + """Lazily build a process-wide connection pool so concurrent workers/requests + don't exhaust Postgres by opening a fresh connection per query.""" + global _POOL + if _POOL is not None: + return _POOL + with _POOL_LOCK: + if _POOL is None: + pool = ConnectionPool( + DATABASE_URL, + min_size=1, + max_size=int(os.getenv("DB_POOL_MAX_SIZE", "10")), + timeout=10, + kwargs={"row_factory": dict_row, "connect_timeout": 5}, + open=False, + ) + pool.open() + _POOL = pool + return _POOL + + def _connect(): if not DB_ENABLED: raise RuntimeError("database disabled") + if ConnectionPool is not None: + # pool.connection() is a context manager that returns the conn to the + # pool on exit, matching the existing `with _connect() as conn:` callers. + return _pool().connection() return psycopg.connect(DATABASE_URL, row_factory=dict_row, connect_timeout=5) @@ -45,12 +82,14 @@ def _json(value: Any): def _execute_safely(label: str, fn): + # DB disabled is an expected, silent no-op; an actual failure while the DB is + # enabled is a real problem (stale job index / dropped audit) and must be loud. if not DB_ENABLED: return None try: return fn() except Exception as exc: - print(f"[db] {label} failed: {exc}", flush=True) + logger.error("[db] %s failed: %s", label, exc) return None diff --git a/api/main.py b/api/main.py index 66c523c..4c3312b 100644 --- a/api/main.py +++ b/api/main.py @@ -5,6 +5,7 @@ import base64 import hashlib import hmac import io +import ipaddress import json import os import random @@ -19,7 +20,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from pathlib import Path from typing import Literal -from urllib.parse import urlencode +from urllib.parse import urlencode, urlparse import httpx from dotenv import load_dotenv @@ -289,6 +290,21 @@ PRODUCT_ASSET_JPEG_QUALITY = max(80, min(95, int(os.getenv("PRODUCT_ASSET_JPEG_Q VIDEO_MODEL = os.getenv("VIDEO_MODEL", "seedance").strip() or "seedance" YTDLP_COOKIES_FILE = os.getenv("YTDLP_COOKIES_FILE", "").strip() YTDLP_COOKIES_FROM_BROWSER = os.getenv("YTDLP_COOKIES_FROM_BROWSER", "").strip() +# Max seconds a single yt-dlp source download may run before run() kills it. +DOWNLOAD_TIMEOUT_SECONDS = max(60, int(os.getenv("DOWNLOAD_TIMEOUT_SECONDS", "600"))) +# SSRF guard for create_job: only fetch source videos from these host suffixes. +# Override/extend via SOURCE_URL_ALLOWED_HOSTS (comma-separated). file://, private +# IPs and metadata endpoints are always rejected regardless of this list. +_DEFAULT_SOURCE_HOSTS = ( + "tiktok.com,douyin.com,ixigua.com,iesdouyin.com," + "youtube.com,youtu.be,instagram.com,facebook.com,fb.watch," + "bilibili.com,xiaohongshu.com,xhslink.com,kuaishou.com,v.kuaishou.com" +) +SOURCE_URL_ALLOWED_HOSTS = { + h.strip().lower() + for h in os.getenv("SOURCE_URL_ALLOWED_HOSTS", _DEFAULT_SOURCE_HOSTS).split(",") + if h.strip() +} AUDIO_PRODUCT_BRIEF = os.getenv( "AUDIO_PRODUCT_BRIEF", "SKG smart massage products for everyday neck-and-shoulder, back, eye, knee, or foot relaxation. Ads should feel premium, clean, trustworthy, and must not make medical efficacy claims.", @@ -1373,6 +1389,12 @@ def job_dir(job_id: str) -> Path: return d +def job_path(job_id: str) -> Path: + """Job dir path WITHOUT creating it — for read-only routes, so an unknown + job id can't be used to spam empty directories under JOBS_DIR.""" + return JOBS_DIR / job_id + + def source_audio_url_for(job_id: str) -> str: return f"/jobs/{job_id}/audio.wav" if (JOBS_DIR / job_id / "audio.wav").exists() else "" @@ -1384,16 +1406,34 @@ def job_with_artifacts(job: Job) -> Job: return job.model_copy(update=updates) +_JOB_LOCKS_GUARD = threading.Lock() +_JOB_LOCKS: dict[str, threading.RLock] = {} + + +def _job_lock(job_id: str) -> threading.RLock: + """Per-job re-entrant lock so concurrent video-queue workers / background + tasks can't lose each other's updates via interleaved read-modify-write on the + shared Job object and state.json (RLock allows update()->save_state() nesting).""" + with _JOB_LOCKS_GUARD: + lock = _JOB_LOCKS.get(job_id) + if lock is None: + lock = threading.RLock() + _JOB_LOCKS[job_id] = lock + return lock + + def save_state(job: Job) -> None: - state_path = job_dir(job.id) / "state.json" - state_path.write_text(job.model_dump_json(indent=2)) - db.index_job(job.model_dump(), str(state_path)) + with _job_lock(job.id): + state_path = job_dir(job.id) / "state.json" + state_path.write_text(job.model_dump_json(indent=2)) + db.index_job(job.model_dump(), str(state_path)) def update(job: Job, **kw) -> None: - for k, v in kw.items(): - setattr(job, k, v) - save_state(job) + with _job_lock(job.id): + for k, v in kw.items(): + setattr(job, k, v) + save_state(job) def public_api_base() -> str: @@ -1877,7 +1917,8 @@ def storyboard_ref_url(job_id: str, ref: dict | None) -> str: def prepare_video_reference(src: Path, dst: Path, size: tuple[int, int] = (720, 1280)) -> None: dst.parent.mkdir(parents=True, exist_ok=True) - img = Image.open(src).convert("RGB") + with Image.open(src) as _raw: + img = _raw.convert("RGB") img.thumbnail(size, Image.Resampling.LANCZOS) canvas = Image.new("RGB", size, (8, 8, 10)) x = (size[0] - img.width) // 2 @@ -1890,15 +1931,18 @@ def update_generated_video(job_id: str, video_id: str, **kw) -> None: job = JOBS.get(job_id) if not job: return - updated = [] - for v in job.generated_videos: - if v.id == video_id: - data = v.model_dump() - data.update(kw) - updated.append(GeneratedVideo(**data)) - else: - updated.append(v) - update(job, generated_videos=updated) + # hold the job lock across the whole read-modify-write so concurrent video + # workers updating different videos on the same job don't clobber the list. + with _job_lock(job_id): + updated = [] + for v in job.generated_videos: + if v.id == video_id: + data = v.model_dump() + data.update(kw) + updated.append(GeneratedVideo(**data)) + else: + updated.append(v) + update(job, generated_videos=updated) def generated_video_exists(job_id: str, video_id: str) -> bool: @@ -2840,9 +2884,15 @@ def _normalize_media_cmd(cmd: list[str]) -> list[str]: return cmd -def run(cmd: list[str], cwd: Path | None = None) -> str: +def run(cmd: list[str], cwd: Path | None = None, timeout: float | None = 300) -> str: cmd = _normalize_media_cmd(cmd) - res = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True) + try: + # timeout guards against a stalled yt-dlp/ffmpeg/ffprobe hanging the worker + # forever (job stuck in downloading/splitting + leaked subprocess). On + # timeout subprocess.run kills the child before re-raising. + res = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, timeout=timeout) + except subprocess.TimeoutExpired as e: + raise RuntimeError(f"cmd timed out after {timeout}s: {' '.join(cmd[:3])}...") from e if res.returncode != 0: # ffmpeg 把 banner 写 stderr,挑最后几行(真错误一般在末尾) tail = "\n".join(res.stderr.splitlines()[-12:]) or res.stderr[-800:] @@ -2850,6 +2900,36 @@ def run(cmd: list[str], cwd: Path | None = None) -> str: return res.stdout +def validate_source_url(url: str) -> str: + """Reject SSRF / local-file vectors before a URL reaches yt-dlp. + + yt-dlp will happily resolve file://, http://169.254.169.254/ (cloud metadata), + loopback and internal hosts, so a raw job URL is an SSRF / local-read primitive. + We require http(s), block private/loopback/link-local IP literals, and restrict + to a configurable allowlist of known short-video platforms. + """ + raw = (url or "").strip() + if not raw: + raise HTTPException(400, "url required") + if raw.startswith("upload://"): + return raw + parsed = urlparse(raw) + if parsed.scheme not in {"http", "https"}: + raise HTTPException(400, "仅支持 http/https 视频链接") + host = (parsed.hostname or "").lower() + if not host: + raise HTTPException(400, "链接缺少主机名") + try: + ip = ipaddress.ip_address(host) + if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved or ip.is_multicast or ip.is_unspecified: + raise HTTPException(400, "不允许访问内网 / 本地地址") + except ValueError: + pass # hostname, not a bare IP literal + if not any(host == suffix or host.endswith("." + suffix) for suffix in SOURCE_URL_ALLOWED_HOSTS): + raise HTTPException(400, "链接域名不在允许列表(仅支持主流短视频平台)") + return raw + + def ytdlp_cookie_args() -> list[str]: if YTDLP_COOKIES_FILE: cookies = Path(YTDLP_COOKIES_FILE).expanduser() @@ -3101,7 +3181,8 @@ def _focus_source_for_element(job_id: str, idx: int, el: KeyElement) -> tuple[Pa if not el.region: return model_src, tmp_focus try: - im = Image.open(src).convert("RGB") + with Image.open(src) as _raw: + im = _raw.convert("RGB") W, H = im.size r = el.region x = max(0.0, min(1.0, float(r.get("x", 0)))) @@ -3144,7 +3225,8 @@ def _make_reference_contact_sheet(job_id: str, frame_indices: list[int], out_pat thumbs: list[Image.Image] = [] for p in paths: try: - im = Image.open(p).convert("RGB") + with Image.open(p) as _raw: + im = _raw.convert("RGB") im.thumbnail((420, 420), Image.Resampling.LANCZOS) canvas = Image.new("RGB", (420, 420), (245, 245, 245)) canvas.paste(im, ((420 - im.width) // 2, (420 - im.height) // 2)) @@ -3182,7 +3264,8 @@ def _make_paths_contact_sheet(paths: list[Path], out_path: Path, max_items: int thumbs: list[Image.Image] = [] for p in usable: try: - im = Image.open(p).convert("RGB") + with Image.open(p) as _raw: + im = _raw.convert("RGB") im.thumbnail((420, 420), Image.Resampling.LANCZOS) canvas = Image.new("RGB", (420, 420), (245, 245, 245)) canvas.paste(im, ((420 - im.width) // 2, (420 - im.height) // 2)) @@ -3591,7 +3674,7 @@ def pipeline_download(job_id: str) -> None: *ytdlp_cookie_args(), job.url, ] - run(cmd) + run(cmd, timeout=DOWNLOAD_TIMEOUT_SECONDS) if not mp4.exists(): raise RuntimeError("下载完成但找不到 source.mp4") @@ -5182,12 +5265,12 @@ def _prepare_image_edit_bytes(image_path: Path, max_side: int) -> bytes: import io as _io from PIL import Image as _PILImage try: - im = _PILImage.open(image_path) - if max(im.size) > max_side: - im.thumbnail((max_side, max_side), _PILImage.LANCZOS) - buf = _io.BytesIO() - im.convert("RGB").save(buf, format="JPEG", quality=88) - return buf.getvalue() + with _PILImage.open(image_path) as im: + if max(im.size) > max_side: + im.thumbnail((max_side, max_side), _PILImage.LANCZOS) + buf = _io.BytesIO() + im.convert("RGB").save(buf, format="JPEG", quality=88) + return buf.getvalue() except Exception: return image_path.read_bytes() @@ -6564,11 +6647,10 @@ def list_jobs(request: Request, limit: int | None = None) -> list[JobSummary]: @app.post("/jobs", response_model=Job) async def create_job(req: CreateJobReq, bg: BackgroundTasks, request: Request) -> Job: - if not req.url.strip(): - raise HTTPException(400, "url required") + safe_url = validate_source_url(req.url) user = data_user_from_request(request) job_id = uuid.uuid4().hex[:12] - job = Job(id=job_id, url=req.url.strip()) + job = Job(id=job_id, url=safe_url) assign_owner(job, user) JOBS[job_id] = job save_state(job) @@ -6585,20 +6667,23 @@ async def retry_job_download(job_id: str, bg: BackgroundTasks) -> Job: source_kind = getattr(job, "source_kind", "") if source_kind == "upload" or job.url.startswith("upload://"): raise HTTPException(409, "uploaded videos cannot be redownloaded; upload the file again") - if job.status in {"downloading", "splitting", "transcribing"}: - raise HTTPException(409, f"job is busy: {job.status}") - mp4 = job_dir(job_id) / "source.mp4" - if mp4.exists() and mp4.stat().st_size == 0: - mp4.unlink() - update( - job, - status="downloading", - progress=1, - error="", - message="重新提交下载…", - video_url="", - ) + # atomic check-and-set: two concurrent retries (or a retry racing the live + # pipeline) must not both launch yt-dlp against the same source.mp4. + with _job_lock(job_id): + if job.status in {"downloading", "splitting", "transcribing"}: + raise HTTPException(409, f"job is busy: {job.status}") + mp4 = job_dir(job_id) / "source.mp4" + if mp4.exists() and mp4.stat().st_size == 0: + mp4.unlink() + update( + job, + status="downloading", + progress=1, + error="", + message="重新提交下载…", + video_url="", + ) bg.add_task(pipeline_download, job_id) return job @@ -6933,7 +7018,7 @@ async def trigger_transcribe(job_id: str, bg: BackgroundTasks) -> Job: @app.get("/jobs/{job_id}/video.mp4") def get_video(job_id: str): - p = job_dir(job_id) / "source.mp4" + p = job_path(job_id) / "source.mp4" if not p.exists(): raise HTTPException(404, "video not found") return FileResponse(p, media_type="video/mp4") @@ -6941,7 +7026,7 @@ def get_video(job_id: str): @app.get("/jobs/{job_id}/audio.wav") def get_source_audio(job_id: str): - p = job_dir(job_id) / "audio.wav" + p = job_path(job_id) / "audio.wav" if not p.exists(): raise HTTPException(404, "audio not found") return FileResponse(p, media_type="audio/wav") @@ -6949,7 +7034,7 @@ def get_source_audio(job_id: str): @app.get("/jobs/{job_id}/audio-script.mp3") def get_audio_script(job_id: str): - p = job_dir(job_id) / "audio_script.mp3" + p = job_path(job_id) / "audio_script.mp3" if not p.exists(): raise HTTPException(404, "audio script not found") return FileResponse(p, media_type="audio/mpeg") @@ -6957,7 +7042,7 @@ def get_audio_script(job_id: str): @app.get("/jobs/{job_id}/frames/{idx}.jpg") def get_frame(job_id: str, idx: int): - p = job_dir(job_id) / "frames" / f"{idx:03d}.jpg" + p = job_path(job_id) / "frames" / f"{idx:03d}.jpg" if not p.exists(): raise HTTPException(404, "frame not found") return FileResponse(p, media_type="image/jpeg") @@ -7060,7 +7145,7 @@ def generate_image(job_id: str, idx: int, req: GenerateReq) -> Job: @app.get("/jobs/{job_id}/frames/{idx}/gen/{gen_id}.jpg") def get_generated_image(job_id: str, idx: int, gen_id: str): - p = job_dir(job_id) / "gen" / f"{idx:03d}_{gen_id}.jpg" + p = job_path(job_id) / "gen" / f"{idx:03d}_{gen_id}.jpg" if not p.exists(): raise HTTPException(404, "generated image not found") return FileResponse(p, media_type="image/jpeg") @@ -7263,7 +7348,7 @@ def cleanup_frame(job_id: str, idx: int, req: CleanupReq | None = None) -> Job: @app.get("/jobs/{job_id}/frames/{idx}/cleaned.jpg") def get_cleaned_frame(job_id: str, idx: int): - p = job_dir(job_id) / "cleaned" / f"{idx:03d}.jpg" + p = job_path(job_id) / "cleaned" / f"{idx:03d}.jpg" if not p.exists(): raise HTTPException(404, "cleaned frame not found") return FileResponse(p, media_type="image/jpeg") @@ -7793,7 +7878,8 @@ def cutout_element(job_id: str, idx: int, element_id: str) -> Job: model_src = src if el.region: try: - im = _PILImage.open(src).convert("RGB") + with _PILImage.open(src) as _raw: + im = _raw.convert("RGB") W, H = im.size r = el.region x = max(0.0, min(1.0, float(r.get("x", 0)))) @@ -9352,7 +9438,7 @@ def batch_generate_all_storyboard(job_id: str, req: BatchGenerateStoryboardReq) @app.get("/jobs/{job_id}/storyboard-videos/{video_id}.mp4") def get_storyboard_video(job_id: str, video_id: str): - p = job_dir(job_id) / "storyboard_videos" / video_id / "video.mp4" + p = job_path(job_id) / "storyboard_videos" / video_id / "video.mp4" if not p.exists(): raise HTTPException(404, "storyboard video not found") return FileResponse(p, media_type="video/mp4") @@ -10115,7 +10201,8 @@ def copy_character_library_assets(job_id: str, req: CopyCharacterLibraryAssetReq asset_id = uuid.uuid4().hex[:12] out = out_dir / f"{asset_id}.jpg" try: - img = Image.open(src).convert("RGB") + with Image.open(src) as _raw: + img = _raw.convert("RGB") img.thumbnail((1600, 1600), Image.Resampling.LANCZOS) img.save(out, "JPEG", quality=94) except Exception as e: @@ -10544,7 +10631,7 @@ def get_agent_run(run_id: str) -> AgentRun: @app.get("/agent-runs/{run_id}/final.mp4") def get_agent_run_final(run_id: str): run = get_agent_run(run_id) - p = job_dir(run.job_id) / "final" / f"agent-{run.id}.mp4" + p = job_path(run.job_id) / "final" / f"agent-{run.id}.mp4" if not p.exists(): raise HTTPException(404, "final video not found") return FileResponse(p, media_type="video/mp4") @@ -10588,9 +10675,11 @@ def create_product_fusion_guide(job_id: str, req: ProductFusionShot) -> dict: h = max(0.02, min(1.0 - y, float(region.h))) try: - base = Image.open(person_path).convert("RGB") + with Image.open(person_path) as _raw: + base = _raw.convert("RGB") base.thumbnail((1600, 1600), Image.Resampling.LANCZOS) - product = product_image_alpha(Image.open(product_path)) + with Image.open(product_path) as _praw: + product = product_image_alpha(_praw) bw, bh = base.size box = ( int(round(x * bw)), @@ -10693,7 +10782,7 @@ def generate_product_fusion_descriptions(job_id: str, req: ProductFusionDescript @app.get("/jobs/{job_id}/assets/{asset_id}.jpg") def get_storyboard_asset(job_id: str, asset_id: str): - p = job_dir(job_id) / "assets" / f"{asset_id}.jpg" + p = job_path(job_id) / "assets" / f"{asset_id}.jpg" if not p.exists(): raise HTTPException(404, "asset not found") return FileResponse(p, media_type="image/jpeg") @@ -10826,7 +10915,7 @@ def remove_storyboard_image(job_id: str, ref_id: str) -> Job: @app.get("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutouts/{cutout_id}.jpg") def get_cutout_versioned(job_id: str, idx: int, element_id: str, cutout_id: str): - p = job_dir(job_id) / "elements" / f"{idx:03d}_{element_id}_{cutout_id}.jpg" + p = job_path(job_id) / "elements" / f"{idx:03d}_{element_id}_{cutout_id}.jpg" if not p.exists(): raise HTTPException(404, "cutout not found") return FileResponse(p, media_type="image/jpeg") @@ -10835,9 +10924,9 @@ def get_cutout_versioned(job_id: str, idx: int, element_id: str, cutout_id: str) @app.get("/jobs/{job_id}/frames/{idx}/elements/{element_id}/cutout.jpg") def get_cutout(job_id: str, idx: int, element_id: str): """旧路径兼容(v1 单图)→ 找 elements/{idx}_{element_id}.jpg 或 .png""" - p = job_dir(job_id) / "elements" / f"{idx:03d}_{element_id}.jpg" + p = job_path(job_id) / "elements" / f"{idx:03d}_{element_id}.jpg" if not p.exists(): - legacy = job_dir(job_id) / "elements" / f"{idx:03d}_{element_id}.png" + legacy = job_path(job_id) / "elements" / f"{idx:03d}_{element_id}.png" if legacy.exists(): return FileResponse(legacy, media_type="image/jpeg") raise HTTPException(404, "cutout not found") diff --git a/api/requirements.txt b/api/requirements.txt index b615890..959c630 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -8,6 +8,7 @@ openai==1.55.3 httpx==0.27.2 requests==2.32.5 psycopg[binary]==3.2.3 +psycopg-pool==3.2.4 imagehash==4.3.1 Pillow>=11.0 numpy>=2.0