Files
20260327-c863ce53/app/data/sources/akshare_source.py
2026-04-25 19:25:22 +08:00

95 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AKShare data source — Chinese macro/industry data via open-source Python library.
Covers: GDP, CPI, PMI, industrial profit, trade balance, and 30+ data categories.
All data returned as Pandas DataFrames, converted to dicts for standardization.
"""
from __future__ import annotations
import logging
from typing import Any
from .base import DataSource, DataResult
logger = logging.getLogger(__name__)
# Lookup table: short request keyword -> name of the AKShare function that
# serves it. Entries are grouped by region; insertion order is China, US, global.
AKSHARE_ENDPOINTS = {
    # --- China macro ---
    "gdp": "macro_china_gdp",
    "cpi": "macro_china_cpi_monthly",
    "ppi": "macro_china_ppi",
    "pmi": "macro_china_pmi",
    "industrial_profit": "macro_china_industrial_profit",
    "trade_balance": "macro_china_trade_balance",
    "money_supply": "macro_china_money_supply",
    "fdi": "macro_china_fdi",
    "real_estate": "macro_china_real_estate",
    "retail_sales": "macro_china_consumer_goods_retail",
    "fixed_asset": "macro_china_fai",
    "unemployment": "macro_china_urban_unemployment",
    # --- US macro ---
    "us_gdp": "macro_usa_gdp_monthly",
    "us_cpi": "macro_usa_cpi_monthly",
    "us_unemployment": "macro_usa_unemployment_rate",
    # --- Global ---
    "global_gdp": "macro_global_gdp",
}
class AKShareSource(DataSource):
    """Data source that serves tabular data by calling AKShare functions.

    A request is resolved to the name of a function on the ``akshare`` module,
    either from an explicit ``endpoint`` keyword argument or by looking for an
    ``AKSHARE_ENDPOINTS`` keyword inside the query text. The function's result
    is trimmed to its last rows and returned as plain dicts.
    """

    name = "akshare"
    description = "中国宏观经济/行业数据免费开源封装统计局等30+数据源)"

    def supports(self, data_type: str, country: str | None = None) -> bool:
        """Return True when ``data_type`` is "macro", "industry" or "general".

        ``country`` is accepted but not consulted.
        """
        return data_type in ("macro", "industry", "general")

    @staticmethod
    def _match_endpoint(query: str) -> str | None:
        """Return the AKShare function name whose keyword occurs in ``query``.

        Matching is a case-insensitive substring test. Longer keywords are
        tried first so that a specific key such as "us_gdp" is not shadowed by
        a shorter key it contains ("gdp"). ``sorted`` is stable, so keywords of
        equal length keep their table order. Returns None when nothing matches.
        """
        query_lower = query.lower()
        for key in sorted(AKSHARE_ENDPOINTS, key=len, reverse=True):
            if key in query_lower:
                return AKSHARE_ENDPOINTS[key]
        return None

    async def fetch(
        self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs,
    ) -> DataResult:
        """Fetch one AKShare dataset and return its most recent rows.

        Args:
            query: Free text searched for an ``AKSHARE_ENDPOINTS`` keyword when
                no explicit endpoint is given.
            data_type: Accepted but not used by this method.
            country: Accepted but not used by this method.
            **kwargs: ``endpoint`` names the AKShare function directly and
                bypasses keyword matching; ``limit`` (default 20) is the number
                of trailing rows to return.

        Returns:
            A ``DataResult``. On success ``data`` holds ``columns``,
            ``records``, ``total_rows`` and ``returned_rows``. On any failure
            (akshare missing, no endpoint match, unknown function, or the call
            raising) ``error`` is set instead; this method does not raise for
            those cases.
        """
        # Function-scope import keeps this block self-contained.
        import asyncio

        try:
            import akshare as ak
        except ImportError:
            return DataResult(source=self.name, error="akshare not installed (pip install akshare)")

        # An explicit endpoint wins; otherwise derive one from the query text.
        endpoint_name = kwargs.get("endpoint") or self._match_endpoint(query)
        if not endpoint_name:
            return DataResult(source=self.name, data=None, error=f"No matching AKShare endpoint for: {query}")

        try:
            func = getattr(ak, endpoint_name, None)
            if not callable(func):
                return DataResult(source=self.name, error=f"AKShare function not found: {endpoint_name}")

            logger.info("[akshare] calling ak.%s()", endpoint_name)
            # The AKShare call is synchronous; run it in the default executor
            # so the event loop is not blocked while it completes.
            loop = asyncio.get_running_loop()
            df = await loop.run_in_executor(None, func)

            # Keep only the trailing rows and convert them to plain dicts.
            limit = kwargs.get("limit", 20)
            recent = df.tail(limit)
            return DataResult(
                source=self.name,
                data={
                    "columns": list(recent.columns),
                    "records": recent.to_dict(orient="records"),
                    "total_rows": len(df),
                    "returned_rows": len(recent),
                },
                metadata={
                    "endpoint": endpoint_name,
                    "description": f"AKShare {endpoint_name}",
                    "format": "tabular",
                },
            )
        except Exception as e:
            # Boundary handler: failures are reported through DataResult.error;
            # the traceback is logged so it is not lost.
            logger.exception("[akshare] ak.%s() failed", endpoint_name)
            return DataResult(source=self.name, error=f"AKShare call failed: {e}")