36 lines
1.1 KiB
Python
36 lines
1.1 KiB
Python
"""Middleware chain — runs ordered middlewares before/after graph execution."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
|
|
from app.graph.state import ReportState
|
|
from .base import Middleware
|
|
|
|
logger = logging.getLogger(__name__)


class MiddlewareChain:
    """Ordered list of middlewares. before() runs in order, after() runs in reverse.

    Middlewares whose ``enabled`` attribute is falsy are skipped in both passes.
    Each hook receives the state returned by the previous hook, and the final
    state is returned to the caller.
    """

    def __init__(self, middlewares: list[Middleware] | None = None):
        """Create a chain, optionally pre-populated with *middlewares*.

        The given sequence is copied, so later calls to :meth:`add` never
        mutate the caller's list, and changes the caller makes to its own
        list do not silently alter this chain.
        """
        self.middlewares = list(middlewares) if middlewares is not None else []

    def add(self, mw: Middleware) -> MiddlewareChain:
        """Append *mw* to the end of the chain and return ``self`` for chaining."""
        self.middlewares.append(mw)
        return self

    async def before(self, state: ReportState) -> ReportState:
        """Run each enabled middleware's ``before`` hook in registration order.

        Returns the state produced by the last hook that ran, or *state*
        unchanged if no middleware is enabled.
        """
        for mw in self.middlewares:
            if mw.enabled:
                # Lazy %-args: the message is only formatted if DEBUG is enabled.
                logger.debug("[middleware:before] %s", mw.name)
                state = await mw.before(state)
        return state

    async def after(self, state: ReportState) -> ReportState:
        """Run each enabled middleware's ``after`` hook in reverse order.

        Reversing mirrors :meth:`before`: the first middleware to see the
        state on the way in is the last to see it on the way out.
        """
        for mw in reversed(self.middlewares):
            if mw.enabled:
                logger.debug("[middleware:after] %s", mw.name)
                state = await mw.after(state)
        return state
|