auto-save 2026-05-14 03:53 (~5)
This commit is contained in:
226
api/main.py
226
api/main.py
@@ -88,6 +88,14 @@ JobStatus = Literal[
|
||||
]
|
||||
|
||||
# Default number of keyframes to extract. Read once at import time from the
# KEYFRAME_COUNT environment variable; falls back to 5 when it is unset.
KEYFRAME_COUNT = int(os.getenv("KEYFRAME_COUNT", "5"))
# Closed set of frame-selection strategies a caller may request.
FrameExtractTarget = Literal["balanced", "subject", "transition", "expression", "motion"]
# Display label (Chinese) for each strategy. These strings are runtime text,
# so the English glosses live in the trailing comments only.
FRAME_TARGET_LABELS: dict[FrameExtractTarget, str] = {
    "balanced": "综合关键帧",  # balanced / overall keyframes
    "subject": "清晰主体",  # clear subject
    "transition": "转场变化",  # scene transitions
    "expression": "表情瞬间",  # expression moments
    "motion": "动作峰值",  # motion peaks
}
|
||||
|
||||
|
||||
class GeneratedImage(BaseModel):
|
||||
@@ -383,37 +391,115 @@ import numpy as np
|
||||
from PIL import Image, ImageEnhance, ImageFilter, ImageOps
|
||||
|
||||
|
||||
def _sharpness(img_path: Path) -> float:
|
||||
def _sharpness_from_gray(g: np.ndarray) -> float:
|
||||
"""Laplacian variance:值越大越清晰,模糊/转场帧值低。"""
|
||||
g = np.asarray(Image.open(img_path).convert("L").resize((320, 180)), dtype=np.float32)
|
||||
lap = (-4 * g[1:-1, 1:-1]
|
||||
+ g[:-2, 1:-1] + g[2:, 1:-1] + g[1:-1, :-2] + g[1:-1, 2:])
|
||||
return float(lap.var())
|
||||
|
||||
|
||||
def _select_keyframes(candidates: list[Path], n: int, dup_threshold: int = 8) -> list[Path]:
|
||||
def _frame_metrics(img_path: Path, idx: int, timestamp: float) -> dict | None:
    """Compute the local scoring features of one low-resolution candidate frame.

    The features are used for ranking only; they are not the final keyframe
    image.

    Args:
        img_path: Path of the scanned frame on disk.
        idx: Position of the frame in the scan sequence.
        timestamp: Time of the frame, stored unchanged in the result.

    Returns:
        A dict with the frame's path, index, timestamp, perceptual hash,
        luma array and scalar metrics, or ``None`` when the image cannot be
        opened or decoded.
    """
    try:
        with Image.open(img_path) as source:
            rgb = source.convert("RGB")
            frame_hash = imagehash.phash(rgb)
            thumb = rgb.resize((160, 90))
    except Exception:
        # Best effort: an unreadable frame is dropped, not fatal.
        return None

    pixels = np.asarray(thumb, dtype=np.float32)
    red = pixels[:, :, 0]
    green = pixels[:, :, 1]
    blue = pixels[:, :, 2]

    # Rec. 601 luma kept on the 0-255 scale so it can be read alongside the
    # sharpness / contrast thresholds.
    luma = (0.299 * red + 0.587 * green + 0.114 * blue).astype(np.float32)
    # Central window of the 90x160 thumbnail (rows 22-67, columns 40-119).
    middle = luma[22:68, 40:120]

    # Opponent-colour statistics: spread of the two channels plus a smaller
    # contribution from their mean offset.
    rg = red - green
    yb = 0.5 * (red + green) - blue
    colorfulness = float(np.sqrt(rg.var() + yb.var()) + 0.3 * np.sqrt(rg.mean() ** 2 + yb.mean() ** 2))

    metrics = {
        "path": img_path,
        "idx": idx,
        "timestamp": timestamp,
        "hash": frame_hash,
        "gray": luma,
        "sharp": _sharpness_from_gray(luma),
        "center_sharp": _sharpness_from_gray(middle),
        "brightness": float(luma.mean()),
        "contrast": float(luma.std()),
        "colorfulness": colorfulness,
        # Temporal metrics start at zero here.
        "scene_score": 0.0,
        "motion": 0.0,
    }
    return metrics
|
||||
|
||||
|
||||
def _attach_temporal_metrics(items: list[dict]) -> None:
|
||||
"""相邻低清帧差异:转场 / 动作目标依赖它,不需要逐帧高分辨率扫描。"""
|
||||
for i, it in enumerate(items):
|
||||
prev_delta = 0.0
|
||||
next_delta = 0.0
|
||||
if i > 0:
|
||||
prev_delta = float(np.mean(np.abs(it["gray"] - items[i - 1]["gray"])) / 255.0)
|
||||
if i + 1 < len(items):
|
||||
next_delta = float(np.mean(np.abs(items[i + 1]["gray"] - it["gray"])) / 255.0)
|
||||
it["scene_score"] = max(prev_delta, next_delta)
|
||||
it["motion"] = (prev_delta + next_delta) / 2.0
|
||||
|
||||
|
||||
def _normalize_item_metrics(items: list[dict]) -> None:
|
||||
for key in ("sharp", "center_sharp", "contrast", "colorfulness", "scene_score", "motion"):
|
||||
vals = [float(it.get(key, 0.0)) for it in items if float(it.get(key, 0.0)) > 0]
|
||||
cap = float(np.percentile(vals, 95)) if vals else 1.0
|
||||
if cap <= 0:
|
||||
cap = 1.0
|
||||
for it in items:
|
||||
it[f"{key}_n"] = min(float(it.get(key, 0.0)) / cap, 1.0)
|
||||
|
||||
|
||||
def _target_score(item: dict, target: FrameExtractTarget) -> float:
|
||||
sharp = float(item.get("sharp_n", 0.0))
|
||||
center = float(item.get("center_sharp_n", 0.0))
|
||||
contrast = float(item.get("contrast_n", 0.0))
|
||||
color = float(item.get("colorfulness_n", 0.0))
|
||||
scene = float(item.get("scene_score_n", 0.0))
|
||||
motion = float(item.get("motion_n", 0.0))
|
||||
|
||||
if target == "subject":
|
||||
score = center * 0.48 + sharp * 0.25 + contrast * 0.17 + color * 0.10
|
||||
elif target == "transition":
|
||||
score = scene * 0.55 + sharp * 0.28 + contrast * 0.12 + color * 0.05
|
||||
elif target == "expression":
|
||||
# 没有额外视觉模型时,表情/动物瞬间只能用中心细节 + 清晰 + 轻微动作变化做本地近似。
|
||||
score = center * 0.40 + sharp * 0.24 + motion * 0.18 + contrast * 0.12 + color * 0.06
|
||||
elif target == "motion":
|
||||
score = motion * 0.45 + sharp * 0.30 + center * 0.15 + contrast * 0.10
|
||||
else:
|
||||
score = sharp * 0.45 + scene * 0.22 + center * 0.15 + contrast * 0.12 + color * 0.06
|
||||
|
||||
brightness = float(item.get("brightness", 0.0))
|
||||
raw_contrast = float(item.get("contrast", 0.0))
|
||||
if raw_contrast < 4 or brightness < 8 or brightness > 247:
|
||||
return score * 0.15
|
||||
if raw_contrast < 9:
|
||||
return score * 0.65
|
||||
return score
|
||||
|
||||
|
||||
def _select_keyframes(candidates: list[dict], n: int, target: FrameExtractTarget, dup_threshold: int = 8) -> list[dict]:
|
||||
"""
|
||||
candidates: 按时间排序的候选帧路径
|
||||
candidates: 按时间排序的低清候选帧评分项
|
||||
n: 目标帧数
|
||||
dup_threshold: pHash 汉明距离 < 此值视为相似(默认 8,64bit hash 大致 ~12.5% 像素差)
|
||||
"""
|
||||
if len(candidates) <= n:
|
||||
return candidates
|
||||
|
||||
# 算 pHash + sharpness
|
||||
items = []
|
||||
for i, p in enumerate(candidates):
|
||||
try:
|
||||
img = Image.open(p)
|
||||
h = imagehash.phash(img)
|
||||
s = _sharpness(p)
|
||||
items.append({"path": p, "idx": i, "hash": h, "sharp": s})
|
||||
except Exception:
|
||||
continue
|
||||
_attach_temporal_metrics(candidates)
|
||||
_normalize_item_metrics(candidates)
|
||||
for it in candidates:
|
||||
it["score"] = _target_score(it, target)
|
||||
|
||||
# 去重:相似帧保留 sharpness 高的
|
||||
# 去重:相似帧保留当前目标下分数更高的
|
||||
deduped: list[dict] = []
|
||||
for it in items:
|
||||
for it in candidates:
|
||||
dup = None
|
||||
for kept in deduped:
|
||||
if (it["hash"] - kept["hash"]) < dup_threshold:
|
||||
@@ -421,10 +507,10 @@ def _select_keyframes(candidates: list[Path], n: int, dup_threshold: int = 8) ->
|
||||
break
|
||||
if dup is None:
|
||||
deduped.append(it)
|
||||
elif it["sharp"] > dup["sharp"]:
|
||||
elif it["score"] > dup["score"]:
|
||||
deduped[deduped.index(dup)] = it
|
||||
|
||||
# 时序分桶:把候选时间轴等分 n 段,每段取去重后 sharpness 最高的
|
||||
# 时序分桶:把候选时间轴等分 n 段,每段取当前目标下最优的
|
||||
total = len(candidates)
|
||||
buckets: list[list[dict]] = [[] for _ in range(n)]
|
||||
for it in deduped:
|
||||
@@ -434,18 +520,18 @@ def _select_keyframes(candidates: list[Path], n: int, dup_threshold: int = 8) ->
|
||||
selected: list[dict] = []
|
||||
for b in buckets:
|
||||
if b:
|
||||
selected.append(max(b, key=lambda x: x["sharp"]))
|
||||
selected.append(max(b, key=lambda x: x["score"]))
|
||||
|
||||
# 空桶补足:从未选的 deduped 里按 sharpness 排序补
|
||||
# 空桶补足:从未选的 deduped 里按目标分数补
|
||||
chosen_paths = {it["path"] for it in selected}
|
||||
remaining = sorted([it for it in deduped if it["path"] not in chosen_paths],
|
||||
key=lambda x: -x["sharp"])
|
||||
key=lambda x: -x["score"])
|
||||
while len(selected) < n and remaining:
|
||||
selected.append(remaining.pop(0))
|
||||
|
||||
# 按时间排序输出
|
||||
selected.sort(key=lambda x: x["idx"])
|
||||
return [it["path"] for it in selected]
|
||||
return selected
|
||||
|
||||
|
||||
def ffprobe_meta(mp4: Path) -> dict:
|
||||
@@ -492,7 +578,11 @@ async def pipeline_download(job_id: str) -> None:
|
||||
update(job, status="failed", error=str(e), message="下载失败")
|
||||
|
||||
|
||||
async def pipeline_analyze(job_id: str, frame_count: int = KEYFRAME_COUNT) -> None:
|
||||
async def pipeline_analyze(
|
||||
job_id: str,
|
||||
frame_count: int = KEYFRAME_COUNT,
|
||||
target: FrameExtractTarget = "balanced",
|
||||
) -> None:
|
||||
"""阶段 2:拆音轨 + 抽关键帧。ASR/翻译是独立文案轨,不阻塞视觉素材流。"""
|
||||
job = JOBS[job_id]
|
||||
d = job_dir(job_id)
|
||||
@@ -510,62 +600,73 @@ async def pipeline_analyze(job_id: str, frame_count: int = KEYFRAME_COUNT) -> No
|
||||
])
|
||||
|
||||
n = max(1, min(int(frame_count), 20))
|
||||
# 候选数:n 的 6 倍或至少 24,封顶 60
|
||||
candidate_count = max(24, min(60, n * 6))
|
||||
target_label = FRAME_TARGET_LABELS.get(target, FRAME_TARGET_LABELS["balanced"])
|
||||
duration = max(float(job.duration or 1.0), 0.1)
|
||||
scan_fps = min(2.0, max(0.02, 180.0 / duration))
|
||||
estimated_scan_count = max(1, int(duration * scan_fps))
|
||||
|
||||
update(job, message=f"抽取候选 {candidate_count} 张…", progress=45)
|
||||
update(job, message=f"低清扫描候选 · {target_label} · 约 {estimated_scan_count} 帧…", progress=45)
|
||||
frames_dir = d / "frames"
|
||||
if frames_dir.exists():
|
||||
shutil.rmtree(frames_dir)
|
||||
frames_dir.mkdir(parents=True)
|
||||
cand_dir = d / "candidates"
|
||||
if cand_dir.exists():
|
||||
shutil.rmtree(cand_dir)
|
||||
cand_dir.mkdir(parents=True)
|
||||
scan_dir = d / "frame_scan"
|
||||
if scan_dir.exists():
|
||||
shutil.rmtree(scan_dir)
|
||||
scan_dir.mkdir(parents=True)
|
||||
|
||||
# 1) 均匀采样大批候选(fast seek,每张 < 0.5s)
|
||||
duration = max(float(job.duration or 1.0), 0.1)
|
||||
step = duration / (candidate_count + 1)
|
||||
candidate_meta: list[tuple[Path, float]] = [] # (path, timestamp)
|
||||
for i in range(candidate_count):
|
||||
t = step * (i + 1)
|
||||
out = cand_dir / f"c_{i:03d}.jpg"
|
||||
# 1) 低分辨率、低帧率扫描。扫描图只用于候选评分,最终不直接作为关键帧。
|
||||
run([
|
||||
"ffmpeg", "-y", "-i", str(mp4),
|
||||
"-vf", f"fps={scan_fps:.4f},scale=360:-2",
|
||||
"-q:v", "4",
|
||||
str(scan_dir / "s_%05d.jpg"),
|
||||
])
|
||||
|
||||
scan_paths = sorted(scan_dir.glob("s_*.jpg"))
|
||||
if not scan_paths:
|
||||
raise RuntimeError("低清扫描没有生成候选帧")
|
||||
|
||||
candidates: list[dict] = []
|
||||
for i, p in enumerate(scan_paths):
|
||||
t = min(i / scan_fps, max(duration - 0.05, 0.0))
|
||||
item = _frame_metrics(p, i, t)
|
||||
if item:
|
||||
candidates.append(item)
|
||||
if not candidates:
|
||||
raise RuntimeError("候选帧评分失败")
|
||||
|
||||
# 2) 目标化筛选:pHash 去重 + 清晰度 / 中心细节 / 转场变化 / 动作强度 + 时序分桶。
|
||||
update(job, message=f"{target_label}筛选 {n} / {len(candidates)} 张…", progress=60)
|
||||
chosen = _select_keyframes(candidates, n, target)
|
||||
|
||||
# 3) 只对最终选中的时间点,从原视频抽高质量关键帧。
|
||||
renamed: list[KeyFrame] = []
|
||||
chosen_sorted = sorted(chosen, key=lambda it: float(it["timestamp"]))
|
||||
for i, item in enumerate(chosen_sorted):
|
||||
dst = frames_dir / f"{i:03d}.jpg"
|
||||
t = float(item["timestamp"])
|
||||
run([
|
||||
"ffmpeg", "-y", "-ss", str(t), "-i", str(mp4),
|
||||
"ffmpeg", "-y", "-ss", f"{t:.3f}", "-i", str(mp4),
|
||||
"-frames:v", "1",
|
||||
"-pix_fmt", "yuvj420p", "-q:v", "3",
|
||||
str(out),
|
||||
str(dst),
|
||||
])
|
||||
if out.exists():
|
||||
candidate_meta.append((out, t))
|
||||
|
||||
# 2) D 启发式选 n 张:pHash 去重 + Laplacian 清晰度 + 时序分桶
|
||||
update(job, message=f"启发式筛选 {n} / {len(candidate_meta)} 张…", progress=60)
|
||||
cand_paths = [m[0] for m in candidate_meta]
|
||||
ts_by_path = {m[0]: m[1] for m in candidate_meta}
|
||||
chosen = _select_keyframes(cand_paths, n)
|
||||
|
||||
# 3) 落盘到 frames/<idx>.jpg
|
||||
renamed: list[KeyFrame] = []
|
||||
chosen_sorted = sorted(chosen, key=lambda p: ts_by_path[p])
|
||||
for i, src in enumerate(chosen_sorted):
|
||||
dst = frames_dir / f"{i:03d}.jpg"
|
||||
shutil.copyfile(src, dst)
|
||||
renamed.append(KeyFrame(
|
||||
index=i,
|
||||
timestamp=round(ts_by_path[src], 2),
|
||||
timestamp=round(t, 2),
|
||||
url=f"/jobs/{job_id}/frames/{i}.jpg",
|
||||
))
|
||||
|
||||
# 4) 清理候选目录
|
||||
shutil.rmtree(cand_dir, ignore_errors=True)
|
||||
# 4) 清理扫描目录
|
||||
shutil.rmtree(scan_dir, ignore_errors=True)
|
||||
|
||||
update(
|
||||
job,
|
||||
status="frames_extracted",
|
||||
frames=renamed,
|
||||
progress=70,
|
||||
message=f"已抽取 {len(renamed)} 张关键帧 · 可继续清洗 / 提取元素 / 分镜编排",
|
||||
message=f"已按「{target_label}」抽取 {len(renamed)} 张关键帧 · 可继续清洗 / 提取元素 / 分镜编排",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -934,13 +1035,18 @@ async def create_job_from_upload(bg: BackgroundTasks, file: UploadFile = File(..
|
||||
|
||||
|
||||
@app.post("/jobs/{job_id}/analyze", response_model=Job)
async def trigger_analyze(
    job_id: str,
    bg: BackgroundTasks,
    frames: int = KEYFRAME_COUNT,
    target: FrameExtractTarget = "balanced",
) -> Job:
    """Queue keyframe analysis for an existing job.

    Args:
        job_id: Key of the job in ``JOBS``.
        bg: Background-task collector; ``pipeline_analyze`` is added to it.
        frames: Requested number of keyframes, passed through unchanged.
        target: Frame-selection strategy, passed through unchanged.

    Returns:
        The job object as looked up at request time.

    Raises:
        HTTPException: 404 when the job does not exist; 409 when the job's
            status is not one that allows (re-)analysis.
    """
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(404, "job not found")

    # The 409 message is built from this tuple, so it always lists exactly
    # the statuses the guard accepts.
    allowed = ("downloaded", "frames_extracted", "transcribed", "failed")
    if job.status not in allowed:
        allowed_text = "/".join(allowed)
        raise HTTPException(409, f"status must be one of {allowed_text}, got {job.status}")

    bg.add_task(pipeline_analyze, job_id, frames, target)
    return job
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user