feat: add xai video model

This commit is contained in:
2026-06-03 16:59:43 +08:00
parent e14acee2a7
commit d038f1b2f4
8 changed files with 228 additions and 56 deletions

View File

@@ -350,9 +350,31 @@ VIDEO_MODEL_ALIASES = {
"veo3": env_video_model("VIDEO_MODEL_VEO3", "veo-3.1-fast"),
"veo": env_video_model("VIDEO_MODEL_VEO3", "veo-3.1-fast"),
"voe": env_video_model("VIDEO_MODEL_VEO3", "veo-3.1-fast"),
"grok_imagine_video": env_video_model("VIDEO_MODEL_XAI", "grok-imagine-video"),
"grok-imagine-video": env_video_model("VIDEO_MODEL_XAI", "grok-imagine-video"),
"xai": env_video_model("VIDEO_MODEL_XAI", "grok-imagine-video"),
}
VIDEO_API_BASE_URL = os.getenv("VIDEO_API_BASE_URL", "").strip()
VIDEO_API_KEY = os.getenv("VIDEO_API_KEY", "").strip()
_VIDEO_XAI_BASE_DEFAULT = (
VIDEO_API_BASE_URL
if "xai" in VIDEO_API_BASE_URL.lower()
else "https://ai.skg.com/ezlink/xai"
)
XAI_VIDEO_API_BASE_URL = (
os.getenv("XAI_VIDEO_API_BASE_URL")
or os.getenv("XAI_GATEWAY_BASE")
or _VIDEO_XAI_BASE_DEFAULT
).strip().rstrip("/")
XAI_VIDEO_API_KEY = (
os.getenv("XAI_VIDEO_API_KEY")
or os.getenv("XAI_GATEWAY_KEY")
or (VIDEO_API_KEY if "xai" in VIDEO_API_BASE_URL.lower() else "")
).strip()
XAI_VIDEO_MODEL = VIDEO_MODEL_ALIASES["xai"]
XAI_VIDEO_CREATE_PATH = os.getenv("XAI_VIDEO_CREATE_PATH", "/v1/videos/generations").strip() or "/v1/videos/generations"
XAI_VIDEO_STATUS_PATH = os.getenv("XAI_VIDEO_STATUS_PATH", "/v1/videos/{id}").strip() or "/v1/videos/{id}"
XAI_VIDEO_CONTENT_PATH = os.getenv("XAI_VIDEO_CONTENT_PATH", "").strip()
WEB_AUTH_USERNAME = os.getenv("WEB_AUTH_USERNAME", "").strip()
WEB_AUTH_PASSWORD = os.getenv("WEB_AUTH_PASSWORD", "").strip()
WEB_AUTH_SESSION_SECRET = os.getenv("WEB_AUTH_SESSION_SECRET", "").strip()
@@ -389,6 +411,12 @@ WEB_AUTH_CONFIGURED = bool(PASSWORD_AUTH_CONFIGURED or FEISHU_AUTH_CONFIGURED)
def default_video_gateway_paths(base_url: str) -> tuple[str, str, str]:
base = base_url.strip().rstrip("/").lower()
if "api.x.ai" in base or "/ezlink/xai" in base:
return (
"/v1/videos/generations",
"/v1/videos/{id}",
"",
)
if "ai.skg.com/doubao" in base:
return (
"/api/v3/contents/generations/tasks",
@@ -1446,13 +1474,30 @@ def video_uses_poe() -> bool:
return bool(POE_API_KEY)
def video_uses_ark() -> bool:
base = video_api_base()
def is_xai_video_model(model: str | None) -> bool:
value = (model or "").strip().lower()
if not value:
value = (VIDEO_MODEL or "").strip().lower()
resolved = VIDEO_MODEL_ALIASES.get(value, value).strip().lower()
xai_model = (XAI_VIDEO_MODEL or "grok-imagine-video").strip().lower()
return resolved == xai_model or resolved.startswith("grok-imagine-video")
def video_uses_xai(model: str | None = None) -> bool:
return is_xai_video_model(model) or "api.x.ai" in video_api_base(model).lower() or "/ezlink/xai" in video_api_base(model).lower()
def video_uses_ark(model: str | None = None) -> bool:
if video_uses_xai(model):
return False
base = video_api_base(model)
return "ark.cn-beijing.volces.com" in base or "ai.skg.com/doubao" in base
def video_provider_name() -> str:
base = video_api_base()
def video_provider_name(model: str | None = None) -> str:
base = video_api_base(model)
if video_uses_xai(model):
return "xai"
if video_uses_poe():
return "poe"
if "ai.skg.com/doubao" in base:
@@ -1462,7 +1507,9 @@ def video_provider_name() -> str:
return "custom"
def video_api_base() -> str:
def video_api_base(model: str | None = None) -> str:
if is_xai_video_model(model):
return XAI_VIDEO_API_BASE_URL.rstrip("/")
if VIDEO_API_BASE_URL:
return VIDEO_API_BASE_URL.rstrip("/")
if POE_API_KEY:
@@ -1470,7 +1517,13 @@ def video_api_base() -> str:
return (LLM_BASE_URL or "https://api.openai.com/v1").rstrip("/")
def video_api_key() -> str:
def video_api_key(model: str | None = None) -> str:
if is_xai_video_model(model):
if XAI_VIDEO_API_KEY:
return XAI_VIDEO_API_KEY
if "xai" in VIDEO_API_BASE_URL.lower() and VIDEO_API_KEY:
return VIDEO_API_KEY
return ""
if VIDEO_API_KEY:
return VIDEO_API_KEY
if video_uses_poe():
@@ -1478,14 +1531,26 @@ def video_api_key() -> str:
return LLM_API_KEY
def video_create_paths(model: str | None = None) -> list[str]:
return [XAI_VIDEO_CREATE_PATH] if video_uses_xai(model) else VIDEO_CREATE_PATHS
def video_status_path(model: str | None = None) -> str:
return XAI_VIDEO_STATUS_PATH if video_uses_xai(model) else VIDEO_STATUS_PATH
def video_content_path(model: str | None = None) -> str:
return XAI_VIDEO_CONTENT_PATH if video_uses_xai(model) else VIDEO_CONTENT_PATH
def video_path(template: str, **values: str) -> str:
path = template.format(**values)
return path if path.startswith("/") else f"/{path}"
def ensure_video_api_configured() -> None:
if not video_api_key():
raise HTTPException(503, "POE_API_KEY、VIDEO_API_KEY 或 LLM_API_KEY 未配置,无法调用生视频 API")
def ensure_video_api_configured(model: str | None = None) -> None:
if not video_api_key(model):
raise HTTPException(503, "POE_API_KEY、VIDEO_API_KEY、XAI_VIDEO_API_KEY 或 LLM_API_KEY 未配置,无法调用生视频 API")
def storyboard_ref_path(job_id: str, ref: dict | None) -> Path | None:
@@ -4973,13 +5038,16 @@ def _image_size_payload(raw: str | None, model: str | None = None) -> dict:
return {} if size == "auto" else {"size": size}
def video_duration_options() -> list[int]:
if video_uses_ark():
def video_duration_options(model: str | None = None) -> list[int]:
if video_uses_ark(model) or video_uses_xai(model):
return [5, 8, 10, 12, 15]
return [4, 8, 12]
def video_size_options() -> list[dict]:
def video_size_options(model: str | None = None) -> list[dict]:
if video_uses_xai(model):
allowed = {"720x1280", "1280x720", "1024x1024"}
return [item for item in VIDEO_SIZE_CHOICES if str(item["value"]) in allowed]
return VIDEO_SIZE_CHOICES
@@ -4992,7 +5060,9 @@ def _video_resolution_choice(value: str) -> dict:
def _video_resolution_values_for_model(model: str | None) -> list[str]:
concrete = (model or "").strip().lower()
if video_uses_ark():
if video_uses_xai(concrete):
return ["480p", "720p"]
if video_uses_ark(concrete):
if "seedance-2-0-fast" in concrete:
return ["480p", "720p"]
if "seedance-2-0" in concrete or "seedance-1-5-pro" in concrete or "seedance-1-0-pro" in concrete:
@@ -5029,7 +5099,7 @@ def _normalize_video_resolution(raw: str | None, model: str | None = None) -> st
return value
def _normalize_video_size(raw: str | None) -> str:
def _normalize_video_size(raw: str | None, model: str | None = None) -> str:
value = (raw or "720x1280").strip().lower().replace(" ", "")
aliases = {
"vertical": "720x1280",
@@ -5046,7 +5116,7 @@ def _normalize_video_size(raw: str | None) -> str:
"3:4": "960x1280",
}
value = aliases.get(value, value)
allowed = {str(item["value"]) for item in VIDEO_SIZE_CHOICES}
allowed = {str(item["value"]) for item in video_size_options(model)}
if value not in allowed:
raise HTTPException(400, f"unsupported video size: {raw}")
return value
@@ -5060,14 +5130,18 @@ def video_model_options() -> list[dict]:
"veo3": "Veo 3",
"veo": "Veo",
"voe": "Veo",
"xai": "Grok Imagine Video",
"grok_imagine_video": "Grok Imagine Video",
"grok-imagine-video": "Grok Imagine Video",
}
concrete_label_map = {
"doubao-seedance-2-0-fast-260128": "Seedance 2.0 Fast",
"doubao-seedance-2-0-260128": "Seedance 2.0 高清",
"grok-imagine-video": "Grok Imagine Video",
}
seen_models: set[str] = set()
options: list[dict] = []
for key in ["seedance", "seedance_hd", "kling", "veo3", "veo"]:
for key in ["seedance", "seedance_hd", "xai", "kling", "veo3", "veo"]:
if key not in VIDEO_MODEL_ALIASES:
continue
model = VIDEO_MODEL_ALIASES[key]
@@ -5078,13 +5152,14 @@ def video_model_options() -> list[dict]:
"id": key,
"label": concrete_label_map.get(model, label_map.get(key, key)),
"model": model,
"description": f"当前视频网关可选模型;单次时长最高 {max(video_duration_options())}",
"duration_options": video_duration_options(),
"size_options": video_size_options(),
"provider": video_provider_name(model),
"description": f"当前视频网关可选模型;单次时长最高 {max(video_duration_options(model))}",
"duration_options": video_duration_options(model),
"size_options": video_size_options(model),
"resolution_options": video_resolution_options(model),
"default_resolution": default_video_resolution(model),
"max_duration_seconds": max(video_duration_options()),
"available": bool(video_api_key()),
"max_duration_seconds": max(video_duration_options(model)),
"available": bool(video_api_key(model)),
})
default_model = resolve_video_model(VIDEO_MODEL)
if not any(item["id"] == VIDEO_MODEL or item["model"] == default_model for item in options):
@@ -5092,13 +5167,14 @@ def video_model_options() -> list[dict]:
"id": VIDEO_MODEL,
"label": label_map.get(VIDEO_MODEL, VIDEO_MODEL),
"model": default_model,
"provider": video_provider_name(default_model),
"description": "默认视频模型",
"duration_options": video_duration_options(),
"size_options": video_size_options(),
"duration_options": video_duration_options(default_model),
"size_options": video_size_options(default_model),
"resolution_options": video_resolution_options(default_model),
"default_resolution": default_video_resolution(default_model),
"max_duration_seconds": max(video_duration_options()),
"available": bool(video_api_key()),
"max_duration_seconds": max(video_duration_options(default_model)),
"available": bool(video_api_key(default_model)),
})
return options
@@ -6585,6 +6661,10 @@ def health() -> dict:
"video_base_url": video_api_base(),
"video_configured": bool(video_api_key()),
"video_create_paths": VIDEO_CREATE_PATHS,
"xai_video_model": XAI_VIDEO_MODEL,
"xai_video_base_url": XAI_VIDEO_API_BASE_URL,
"xai_video_configured": bool(video_api_key(XAI_VIDEO_MODEL)),
"xai_video_create_path": XAI_VIDEO_CREATE_PATH,
},
}
@@ -8832,8 +8912,8 @@ class ProductFusionDescriptionReq(BaseModel):
shots: list[ProductFusionShot] = Field(default_factory=list)
def video_seconds(duration: float) -> str:
if video_uses_ark():
def video_seconds(duration: float, model: str | None = None) -> str:
if video_uses_ark(model) or video_uses_xai(model):
if duration <= 0:
return "5"
return str(max(4, min(15, round(duration))))
@@ -8848,7 +8928,7 @@ def resolve_video_model(raw: str | None) -> str:
requested = (raw or VIDEO_MODEL or "seedance").strip()
lowered = requested.lower()
if lowered in {"sora", "sora-2", "sora_2"}:
raise HTTPException(400, "Sora 已停用,请选择当前已接入的 Seedance")
raise HTTPException(400, "Sora 已停用,请选择当前已接入的 Seedance 或 Grok Imagine Video")
return VIDEO_MODEL_ALIASES.get(lowered, requested)
@@ -8897,6 +8977,12 @@ def video_url_from_response(data: dict) -> str:
v = content.get(key)
if isinstance(v, str) and v:
return v
video = data.get("video")
if isinstance(video, dict):
for key in ("url", "video_url", "download_url", "file_url"):
v = video.get(key)
if isinstance(v, str) and v:
return v
return ""
@@ -8987,12 +9073,15 @@ def _video_create_failure_message(create_errors: list[str]) -> str:
return "视频生成失败:视频模型没有接受本次请求。请换一张参考图或简化提示词后重试;如果持续失败,请联系管理员。"
def download_generated_video(client, base: str, headers: dict, provider_id: str, direct_url: str, out_mp4: Path) -> None:
def download_generated_video(client, base: str, headers: dict, provider_id: str, direct_url: str, out_mp4: Path, model: str | None = None) -> None:
if direct_url:
url = direct_url if direct_url.startswith("http") else f"{base}{direct_url if direct_url.startswith('/') else '/' + direct_url}"
r = client.get(url, headers=headers if url.startswith(base) else None)
else:
r = client.get(f"{base}{video_path(VIDEO_CONTENT_PATH, id=provider_id)}", headers=headers)
content_path = video_content_path(model)
if not content_path:
raise RuntimeError("视频生成完成但未返回可下载地址")
r = client.get(f"{base}{video_path(content_path, id=provider_id)}", headers=headers)
r.raise_for_status()
out_mp4.write_bytes(r.content)
@@ -9032,7 +9121,33 @@ def submit_video_create(
product_imgs: list[Path] | None = None,
primary_role: str = "first_frame",
):
if video_uses_ark():
model = str(payload.get("model") or "")
if video_uses_xai(model):
duration = int(float(str(payload.get("duration") or payload.get(VIDEO_DURATION_FIELD) or 8)))
data: dict = {
"model": model,
"prompt": payload["prompt"],
"duration": max(1, duration),
"aspect_ratio": size_to_video_ratio(str(payload.get("size", ""))),
"resolution": _normalize_video_resolution(str(payload.get("resolution") or ""), model),
}
reference_images: list[dict] = []
if ref_img.exists() and primary_role:
ref_payload = {"url": ark_reference_data_url(ref_img)}
if primary_role == "first_frame":
data["image"] = ref_payload
else:
reference_images.append(ref_payload)
if last_img and last_img.exists():
reference_images.append({"url": ark_reference_data_url(last_img)})
for product_img in (product_imgs or [])[:6]:
if product_img.exists():
reference_images.append({"url": ark_reference_data_url(product_img)})
if reference_images:
data["reference_images"] = reference_images[:6]
return client.post(url, headers={**headers, "Content-Type": "application/json"}, json=data)
if video_uses_ark(model):
content = [{"type": "text", "text": payload["prompt"]}]
if source_ref and source_ref.kind == "source_video" and source_ref.url:
content.append(
@@ -9046,7 +9161,7 @@ def submit_video_create(
{
"type": "image_url",
"image_url": {"url": ark_reference_data_url(ref_img)},
"role": primary_role,
"role": primary_role or "reference_image",
}
)
if last_img and last_img.exists():
@@ -9112,8 +9227,8 @@ def render_storyboard_video(
ref_img = out_dir / "reference.jpg"
last_img = out_dir / "last_reference.jpg"
out_mp4 = out_dir / "video.mp4"
base = video_api_base()
headers = {"Authorization": f"Bearer {video_api_key()}"}
base = video_api_base(model)
headers = {"Authorization": f"Bearer {video_api_key(model)}"}
try:
prepare_video_reference(ref_path, ref_img)
@@ -9133,15 +9248,15 @@ def render_storyboard_video(
payload[VIDEO_DURATION_FIELD] = seconds
create = None
create_errors: list[str] = []
for create_path in VIDEO_CREATE_PATHS:
for create_path in video_create_paths(model):
resp = submit_video_create(client, f"{base}{video_path(create_path)}", headers, ref_img, payload, source_ref, prepared_last_img, prepared_product_imgs, primary_role)
if video_uses_ark() and source_ref and resp.status_code in {400, 422}:
if video_uses_ark(model) and source_ref and resp.status_code in {400, 422}:
create_errors.append(f"{video_path(create_path)} + reference_video -> HTTP {resp.status_code}: {resp.text[:700]}")
resp = submit_video_create(client, f"{base}{video_path(create_path)}", headers, ref_img, payload, None, prepared_last_img, prepared_product_imgs, primary_role)
if video_uses_ark() and prepared_last_img and resp.status_code in {400, 422}:
if video_uses_ark(model) and prepared_last_img and resp.status_code in {400, 422}:
create_errors.append(f"{video_path(create_path)} + last_frame -> HTTP {resp.status_code}: {resp.text[:700]}")
resp = submit_video_create(client, f"{base}{video_path(create_path)}", headers, ref_img, payload, None, None, prepared_product_imgs, primary_role)
if video_uses_ark() and prepared_product_imgs and resp.status_code in {400, 422}:
if video_uses_ark(model) and prepared_product_imgs and resp.status_code in {400, 422}:
create_errors.append(f"{video_path(create_path)} + product_reference -> HTTP {resp.status_code}: {resp.text[:700]}")
resp = submit_video_create(client, f"{base}{video_path(create_path)}", headers, ref_img, payload, None, prepared_last_img, None, primary_role)
if resp.status_code < 400:
@@ -9154,7 +9269,7 @@ def render_storyboard_video(
print(f"[video create failed] job={job_id} video={local_id} errors={' | '.join(create_errors)[:1800]}", flush=True)
raise RuntimeError(_video_create_failure_message(create_errors))
data = create.json()
video_api_id = data.get("id") or provider_id or local_id
video_api_id = data.get("request_id") or data.get("id") or provider_id or local_id
status = normalize_video_status(data.get("status"))
progress = video_progress(data, 5)
direct_url = video_url_from_response(data)
@@ -9171,7 +9286,7 @@ def render_storyboard_video(
deadline = time.time() + VIDEO_POLL_TIMEOUT_SECONDS
while status in {"queued", "in_progress"} and time.time() < deadline:
time.sleep(8)
poll = client.get(f"{base}{video_path(VIDEO_STATUS_PATH, id=video_api_id)}", headers=headers)
poll = client.get(f"{base}{video_path(video_status_path(model), id=video_api_id)}", headers=headers)
poll.raise_for_status()
pdata = poll.json()
status = normalize_video_status(pdata.get("status"))
@@ -9200,7 +9315,7 @@ def render_storyboard_video(
update_generated_video(job_id, local_id, status="failed", error=_video_public_error(raw_error or f"video status: {status}"), progress=progress, queue_message="")
return
download_generated_video(client, base, headers, video_api_id, direct_url, out_mp4)
download_generated_video(client, base, headers, video_api_id, direct_url, out_mp4, model)
update_generated_video(
job_id,
local_id,
@@ -9286,7 +9401,6 @@ def refine_storyboard(job_id: str, idx: int, req: RefineStoryboardReq) -> dict:
def _enqueue_storyboard_videos(job: Job, frame: KeyFrame, req: GenerateStoryboardVideoReq, bg: BackgroundTasks | None = None) -> list[str]:
ensure_video_api_configured()
prompt = _ensure_english(req.prompt.strip())
if not prompt and frame.storyboard:
prompt = _storyboard_video_prompt(frame.storyboard, req.seed)
@@ -9295,7 +9409,7 @@ def _enqueue_storyboard_videos(job: Job, frame: KeyFrame, req: GenerateStoryboar
count = max(1, min(12, int(req.count or 1)))
ref = req.first_image or req.subject_image or req.product_image or req.scene_image or req.action_image
primary_role = "first_frame" if req.first_image else "reference_image"
primary_role = "first_frame" if req.first_image else ("reference_image" if ref else "")
ref_path = storyboard_ref_path(job.id, ref) or (job_dir(job.id) / "frames" / f"{frame.index:03d}.jpg")
if not ref_path.exists():
raise HTTPException(404, "reference image missing")
@@ -9315,13 +9429,23 @@ def _enqueue_storyboard_videos(job: Job, frame: KeyFrame, req: GenerateStoryboar
seen_ref_paths.add(key)
model = resolve_video_model(req.model)
seconds = video_seconds(float(req.duration or 4))
video_size = _normalize_video_size(req.size)
ensure_video_api_configured(model)
seconds = video_seconds(float(req.duration or 4), model)
video_size = _normalize_video_size(req.size, model)
video_resolution = _normalize_video_resolution(req.resolution, model)
source_ref = req.source_ref
if source_ref and source_ref.kind == "source_video" and not source_ref.url:
source_ref = None
has_visual_reference = bool(ref_path.exists() or last_ref_path or reference_ref_paths)
has_visual_reference = bool(
req.first_image
or req.subject_image
or req.product_image
or req.scene_image
or req.action_image
or req.last_image
or raw_product_refs
or req.subject_images
)
items: list[GeneratedVideo] = []
ids: list[str] = []
queued_tasks: list[tuple[str, tuple]] = []