auto-save 2026-05-14 03:53 (~5)

This commit is contained in:
2026-05-14 03:53:51 +08:00
parent 9572111254
commit 6eb1f98e06
5 changed files with 234 additions and 68 deletions

View File

@@ -88,6 +88,14 @@ JobStatus = Literal[
]
KEYFRAME_COUNT = int(os.getenv("KEYFRAME_COUNT", "5"))
FrameExtractTarget = Literal["balanced", "subject", "transition", "expression", "motion"]
FRAME_TARGET_LABELS: dict[FrameExtractTarget, str] = {
"balanced": "综合关键帧",
"subject": "清晰主体",
"transition": "转场变化",
"expression": "表情瞬间",
"motion": "动作峰值",
}
class GeneratedImage(BaseModel):
@@ -383,37 +391,115 @@ import numpy as np
from PIL import Image, ImageEnhance, ImageFilter, ImageOps
def _sharpness(img_path: Path) -> float:
def _sharpness_from_gray(g: np.ndarray) -> float:
"""Laplacian variance值越大越清晰模糊/转场帧值低。"""
g = np.asarray(Image.open(img_path).convert("L").resize((320, 180)), dtype=np.float32)
lap = (-4 * g[1:-1, 1:-1]
+ g[:-2, 1:-1] + g[2:, 1:-1] + g[1:-1, :-2] + g[1:-1, 2:])
return float(lap.var())
def _select_keyframes(candidates: list[Path], n: int, dup_threshold: int = 8) -> list[Path]:
def _frame_metrics(img_path: Path, idx: int, timestamp: float) -> dict | None:
"""低清候选帧的本地评分特征。只用于排序,最终仍从原视频抽原尺寸帧。"""
try:
with Image.open(img_path) as raw:
img = raw.convert("RGB")
h = imagehash.phash(img)
small = img.resize((160, 90))
except Exception:
return None
arr = np.asarray(small, dtype=np.float32)
# Rec. 601 luma保留 0-255 范围,便于和清晰度 / 对比度阈值一起看。
gray = (0.299 * arr[:, :, 0] + 0.587 * arr[:, :, 1] + 0.114 * arr[:, :, 2]).astype(np.float32)
center = gray[22:68, 40:120]
rg = arr[:, :, 0] - arr[:, :, 1]
yb = 0.5 * (arr[:, :, 0] + arr[:, :, 1]) - arr[:, :, 2]
colorfulness = float(np.sqrt(rg.var() + yb.var()) + 0.3 * np.sqrt(rg.mean() ** 2 + yb.mean() ** 2))
return {
"path": img_path,
"idx": idx,
"timestamp": timestamp,
"hash": h,
"gray": gray,
"sharp": _sharpness_from_gray(gray),
"center_sharp": _sharpness_from_gray(center),
"brightness": float(gray.mean()),
"contrast": float(gray.std()),
"colorfulness": colorfulness,
"scene_score": 0.0,
"motion": 0.0,
}
def _attach_temporal_metrics(items: list[dict]) -> None:
"""相邻低清帧差异:转场 / 动作目标依赖它,不需要逐帧高分辨率扫描。"""
for i, it in enumerate(items):
prev_delta = 0.0
next_delta = 0.0
if i > 0:
prev_delta = float(np.mean(np.abs(it["gray"] - items[i - 1]["gray"])) / 255.0)
if i + 1 < len(items):
next_delta = float(np.mean(np.abs(items[i + 1]["gray"] - it["gray"])) / 255.0)
it["scene_score"] = max(prev_delta, next_delta)
it["motion"] = (prev_delta + next_delta) / 2.0
def _normalize_item_metrics(items: list[dict]) -> None:
for key in ("sharp", "center_sharp", "contrast", "colorfulness", "scene_score", "motion"):
vals = [float(it.get(key, 0.0)) for it in items if float(it.get(key, 0.0)) > 0]
cap = float(np.percentile(vals, 95)) if vals else 1.0
if cap <= 0:
cap = 1.0
for it in items:
it[f"{key}_n"] = min(float(it.get(key, 0.0)) / cap, 1.0)
def _target_score(item: dict, target: FrameExtractTarget) -> float:
sharp = float(item.get("sharp_n", 0.0))
center = float(item.get("center_sharp_n", 0.0))
contrast = float(item.get("contrast_n", 0.0))
color = float(item.get("colorfulness_n", 0.0))
scene = float(item.get("scene_score_n", 0.0))
motion = float(item.get("motion_n", 0.0))
if target == "subject":
score = center * 0.48 + sharp * 0.25 + contrast * 0.17 + color * 0.10
elif target == "transition":
score = scene * 0.55 + sharp * 0.28 + contrast * 0.12 + color * 0.05
elif target == "expression":
# 没有额外视觉模型时,表情/动物瞬间只能用中心细节 + 清晰 + 轻微动作变化做本地近似。
score = center * 0.40 + sharp * 0.24 + motion * 0.18 + contrast * 0.12 + color * 0.06
elif target == "motion":
score = motion * 0.45 + sharp * 0.30 + center * 0.15 + contrast * 0.10
else:
score = sharp * 0.45 + scene * 0.22 + center * 0.15 + contrast * 0.12 + color * 0.06
brightness = float(item.get("brightness", 0.0))
raw_contrast = float(item.get("contrast", 0.0))
if raw_contrast < 4 or brightness < 8 or brightness > 247:
return score * 0.15
if raw_contrast < 9:
return score * 0.65
return score
def _select_keyframes(candidates: list[dict], n: int, target: FrameExtractTarget, dup_threshold: int = 8) -> list[dict]:
"""
candidates: 按时间排序的候选帧路径
candidates: 按时间排序的低清候选帧评分项
n: 目标帧数
dup_threshold: pHash 汉明距离 < 此值视为相似(默认 864bit hash 大致 ~12.5% 像素差)
"""
if len(candidates) <= n:
return candidates
# 算 pHash + sharpness
items = []
for i, p in enumerate(candidates):
try:
img = Image.open(p)
h = imagehash.phash(img)
s = _sharpness(p)
items.append({"path": p, "idx": i, "hash": h, "sharp": s})
except Exception:
continue
_attach_temporal_metrics(candidates)
_normalize_item_metrics(candidates)
for it in candidates:
it["score"] = _target_score(it, target)
# 去重:相似帧保留 sharpness 高的
# 去重:相似帧保留当前目标下分数更高的
deduped: list[dict] = []
for it in items:
for it in candidates:
dup = None
for kept in deduped:
if (it["hash"] - kept["hash"]) < dup_threshold:
@@ -421,10 +507,10 @@ def _select_keyframes(candidates: list[Path], n: int, dup_threshold: int = 8) ->
break
if dup is None:
deduped.append(it)
elif it["sharp"] > dup["sharp"]:
elif it["score"] > dup["score"]:
deduped[deduped.index(dup)] = it
# 时序分桶:把候选时间轴等分 n 段,每段取去重后 sharpness 最高
# 时序分桶:把候选时间轴等分 n 段,每段取当前目标下最优
total = len(candidates)
buckets: list[list[dict]] = [[] for _ in range(n)]
for it in deduped:
@@ -434,18 +520,18 @@ def _select_keyframes(candidates: list[Path], n: int, dup_threshold: int = 8) ->
selected: list[dict] = []
for b in buckets:
if b:
selected.append(max(b, key=lambda x: x["sharp"]))
selected.append(max(b, key=lambda x: x["score"]))
# 空桶补足:从未选的 deduped 里按 sharpness 排序
# 空桶补足:从未选的 deduped 里按目标分数
chosen_paths = {it["path"] for it in selected}
remaining = sorted([it for it in deduped if it["path"] not in chosen_paths],
key=lambda x: -x["sharp"])
key=lambda x: -x["score"])
while len(selected) < n and remaining:
selected.append(remaining.pop(0))
# 按时间排序输出
selected.sort(key=lambda x: x["idx"])
return [it["path"] for it in selected]
return selected
def ffprobe_meta(mp4: Path) -> dict:
@@ -492,7 +578,11 @@ async def pipeline_download(job_id: str) -> None:
update(job, status="failed", error=str(e), message="下载失败")
async def pipeline_analyze(job_id: str, frame_count: int = KEYFRAME_COUNT) -> None:
async def pipeline_analyze(
job_id: str,
frame_count: int = KEYFRAME_COUNT,
target: FrameExtractTarget = "balanced",
) -> None:
"""阶段 2拆音轨 + 抽关键帧。ASR/翻译是独立文案轨,不阻塞视觉素材流。"""
job = JOBS[job_id]
d = job_dir(job_id)
@@ -510,62 +600,73 @@ async def pipeline_analyze(job_id: str, frame_count: int = KEYFRAME_COUNT) -> No
])
n = max(1, min(int(frame_count), 20))
# 候选数n 的 6 倍或至少 24封顶 60
candidate_count = max(24, min(60, n * 6))
target_label = FRAME_TARGET_LABELS.get(target, FRAME_TARGET_LABELS["balanced"])
duration = max(float(job.duration or 1.0), 0.1)
scan_fps = min(2.0, max(0.02, 180.0 / duration))
estimated_scan_count = max(1, int(duration * scan_fps))
update(job, message=f"抽取候选 {candidate_count} ", progress=45)
update(job, message=f"低清扫描候选 · {target_label} · 约 {estimated_scan_count} ", progress=45)
frames_dir = d / "frames"
if frames_dir.exists():
shutil.rmtree(frames_dir)
frames_dir.mkdir(parents=True)
cand_dir = d / "candidates"
if cand_dir.exists():
shutil.rmtree(cand_dir)
cand_dir.mkdir(parents=True)
scan_dir = d / "frame_scan"
if scan_dir.exists():
shutil.rmtree(scan_dir)
scan_dir.mkdir(parents=True)
# 1) 均匀采样大批候选fast seek每张 < 0.5s
duration = max(float(job.duration or 1.0), 0.1)
step = duration / (candidate_count + 1)
candidate_meta: list[tuple[Path, float]] = [] # (path, timestamp)
for i in range(candidate_count):
t = step * (i + 1)
out = cand_dir / f"c_{i:03d}.jpg"
# 1) 低分辨率、低帧率扫描。扫描图只用于候选评分,最终不直接作为关键帧。
run([
"ffmpeg", "-y", "-i", str(mp4),
"-vf", f"fps={scan_fps:.4f},scale=360:-2",
"-q:v", "4",
str(scan_dir / "s_%05d.jpg"),
])
scan_paths = sorted(scan_dir.glob("s_*.jpg"))
if not scan_paths:
raise RuntimeError("低清扫描没有生成候选帧")
candidates: list[dict] = []
for i, p in enumerate(scan_paths):
t = min(i / scan_fps, max(duration - 0.05, 0.0))
item = _frame_metrics(p, i, t)
if item:
candidates.append(item)
if not candidates:
raise RuntimeError("候选帧评分失败")
# 2) 目标化筛选pHash 去重 + 清晰度 / 中心细节 / 转场变化 / 动作强度 + 时序分桶。
update(job, message=f"{target_label}筛选 {n} / {len(candidates)} 张…", progress=60)
chosen = _select_keyframes(candidates, n, target)
# 3) 只对最终选中的时间点,从原视频抽高质量关键帧。
renamed: list[KeyFrame] = []
chosen_sorted = sorted(chosen, key=lambda it: float(it["timestamp"]))
for i, item in enumerate(chosen_sorted):
dst = frames_dir / f"{i:03d}.jpg"
t = float(item["timestamp"])
run([
"ffmpeg", "-y", "-ss", str(t), "-i", str(mp4),
"ffmpeg", "-y", "-ss", f"{t:.3f}", "-i", str(mp4),
"-frames:v", "1",
"-pix_fmt", "yuvj420p", "-q:v", "3",
str(out),
str(dst),
])
if out.exists():
candidate_meta.append((out, t))
# 2) D 启发式选 n 张pHash 去重 + Laplacian 清晰度 + 时序分桶
update(job, message=f"启发式筛选 {n} / {len(candidate_meta)} 张…", progress=60)
cand_paths = [m[0] for m in candidate_meta]
ts_by_path = {m[0]: m[1] for m in candidate_meta}
chosen = _select_keyframes(cand_paths, n)
# 3) 落盘到 frames/<idx>.jpg
renamed: list[KeyFrame] = []
chosen_sorted = sorted(chosen, key=lambda p: ts_by_path[p])
for i, src in enumerate(chosen_sorted):
dst = frames_dir / f"{i:03d}.jpg"
shutil.copyfile(src, dst)
renamed.append(KeyFrame(
index=i,
timestamp=round(ts_by_path[src], 2),
timestamp=round(t, 2),
url=f"/jobs/{job_id}/frames/{i}.jpg",
))
# 4) 清理候选目录
shutil.rmtree(cand_dir, ignore_errors=True)
# 4) 清理扫描目录
shutil.rmtree(scan_dir, ignore_errors=True)
update(
job,
status="frames_extracted",
frames=renamed,
progress=70,
message=f"已抽取 {len(renamed)} 张关键帧 · 可继续清洗 / 提取元素 / 分镜编排",
message=f"按「{target_label}抽取 {len(renamed)} 张关键帧 · 可继续清洗 / 提取元素 / 分镜编排",
)
except Exception as e:
@@ -934,13 +1035,18 @@ async def create_job_from_upload(bg: BackgroundTasks, file: UploadFile = File(..
@app.post("/jobs/{job_id}/analyze", response_model=Job)
async def trigger_analyze(job_id: str, bg: BackgroundTasks, frames: int = KEYFRAME_COUNT) -> Job:
async def trigger_analyze(
job_id: str,
bg: BackgroundTasks,
frames: int = KEYFRAME_COUNT,
target: FrameExtractTarget = "balanced",
) -> Job:
job = JOBS.get(job_id)
if not job:
raise HTTPException(404, "job not found")
if job.status not in {"downloaded", "frames_extracted", "transcribed", "failed"}:
raise HTTPException(409, f"status must be downloaded/failed, got {job.status}")
bg.add_task(pipeline_analyze, job_id, frames)
bg.add_task(pipeline_analyze, job_id, frames, target)
return job