95 lines
3.3 KiB
Python
95 lines
3.3 KiB
Python
"""AKShare data source: Chinese macro/industry data via the open-source AKShare Python library.
|
||
|
||
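Covers GDP, CPI, PPI, PMI, industrial profit, trade balance, money supply and other Chinese, US and global macro series (see AKSHARE_ENDPOINTS); AKShare itself wraps 30+ upstream data sources.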
|
||
AKShare returns Pandas DataFrames; the most recent rows are converted to plain dicts (columns + records) for a standardized result.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
from .base import DataSource, DataResult
|
||
|
||
logger = logging.getLogger(__name__)

# Short, human-friendly keys mapped to the AKShare function serving each series.
# AKShareSource.fetch looks for these keys (lowercased) inside the query text;
# iteration order matters when a query contains several keys.
AKSHARE_ENDPOINTS = dict(
    # China macro
    gdp="macro_china_gdp",
    cpi="macro_china_cpi_monthly",
    ppi="macro_china_ppi",
    pmi="macro_china_pmi",
    industrial_profit="macro_china_industrial_profit",
    trade_balance="macro_china_trade_balance",
    money_supply="macro_china_money_supply",
    fdi="macro_china_fdi",
    real_estate="macro_china_real_estate",
    retail_sales="macro_china_consumer_goods_retail",
    fixed_asset="macro_china_fai",
    unemployment="macro_china_urban_unemployment",
    # US macro
    us_gdp="macro_usa_gdp_monthly",
    us_cpi="macro_usa_cpi_monthly",
    us_unemployment="macro_usa_unemployment_rate",
    # Global
    global_gdp="macro_global_gdp",
)
|
||
|
||
|
||
class AKShareSource(DataSource):
    """Data source backed by the AKShare library (macro/industry series).

    A request is resolved to an AKShare function name, either from an explicit
    ``endpoint`` kwarg or by matching ``AKSHARE_ENDPOINTS`` keys against the
    query text. The function is called, and the tail of the resulting
    DataFrame is returned as column/record dicts.
    """

    name = "akshare"
    # Runtime description string (Chinese): "Chinese macro/industry data
    # (free & open source, wraps NBS and 30+ data sources)".
    description = "中国宏观经济/行业数据(免费开源,封装统计局等30+数据源)"

    def supports(self, data_type: str, country: str | None = None) -> bool:
        """Return True if ``data_type`` is one this source serves.

        ``country`` is accepted for interface compatibility but is not used.
        """
        return data_type in ("macro", "industry", "general")

    @staticmethod
    def _resolve_endpoint(query: str, explicit: str | None) -> str | None:
        """Map an explicit endpoint or a free-text query to an AKShare function name."""
        if explicit:
            # Accept either a short key ("gdp") or a raw AKShare function name.
            return AKSHARE_ENDPOINTS.get(explicit, explicit)

        query_lower = (query or "").lower()
        matches = [key for key in AKSHARE_ENDPOINTS if key in query_lower]
        for key in matches:
            # Skip a key that is only matching as part of a longer matched key
            # (e.g. "gdp" inside "us_gdp"); otherwise the short China keys would
            # shadow "us_gdp", "global_gdp", "us_cpi" and "us_unemployment".
            if not any(key != other and key in other for other in matches):
                return AKSHARE_ENDPOINTS[key]
        return None

    async def fetch(
        self, query: str, *, data_type: str = "general", country: str | None = None, **kwargs,
    ) -> DataResult:
        """Fetch the most recent rows of an AKShare series.

        Args:
            query: Free-text query; lowercased and searched for the keys of
                ``AKSHARE_ENDPOINTS``.
            data_type: Not used here; see :meth:`supports`.
            country: Not used here.
            **kwargs: ``endpoint`` -- an AKShare function name or an
                ``AKSHARE_ENDPOINTS`` key, bypassing query matching;
                ``limit`` -- number of trailing rows to return (default 20).

        Returns:
            A ``DataResult``. On success, ``data`` holds ``columns``,
            ``records``, ``total_rows`` and ``returned_rows``. On failure,
            ``error`` describes the problem. AKShare errors are reported
            through ``error`` rather than raised.
        """
        try:
            import akshare as ak
        except ImportError:
            return DataResult(source=self.name, error="akshare not installed (pip install akshare)")

        import asyncio

        endpoint_name = self._resolve_endpoint(query, kwargs.get("endpoint"))
        if not endpoint_name:
            return DataResult(source=self.name, data=None, error=f"No matching AKShare endpoint for: {query}")

        try:
            func = getattr(ak, endpoint_name, None)
            if not callable(func):
                return DataResult(source=self.name, error=f"AKShare function not found: {endpoint_name}")

            logger.info("[akshare] calling ak.%s()", endpoint_name)
            # AKShare functions perform blocking network I/O; run them in the
            # default executor so this coroutine does not stall the event loop.
            loop = asyncio.get_running_loop()
            df = await loop.run_in_executor(None, func)

            # Keep only the most recent rows, then convert to plain
            # column/record dicts for serialization.
            limit = kwargs.get("limit", 20)
            recent = df.tail(limit)

            return DataResult(
                source=self.name,
                data={
                    "columns": list(recent.columns),
                    "records": recent.to_dict(orient="records"),
                    "total_rows": len(df),
                    "returned_rows": len(recent),
                },
                metadata={
                    "endpoint": endpoint_name,
                    "description": f"AKShare {endpoint_name}",
                    "format": "tabular",
                },
            )
        except Exception as e:
            # Boundary of an external library call: report rather than raise,
            # but leave a trace in the logs instead of failing silently.
            logger.warning("[akshare] ak.%s() failed: %s", endpoint_name, e)
            return DataResult(source=self.name, error=f"AKShare call failed: {e}")
|