# Files
# 20250920-e194e889/database.py
# 2026-04-25 19:21:28 +08:00
#
# 226 lines
# 7.5 KiB
# Python

"""
数据库管理模块
"""
import json
import sqlite3
from contextlib import contextmanager
from datetime import datetime
from typing import Dict, List, Optional

import pandas as pd
class StockDatabase:
    """SQLite-backed store for company info, stock prices, financial data
    and analysis results.

    Each public method opens its own short-lived connection to ``db_path``
    and closes it before returning, including when an error is raised.
    """

    def __init__(self, db_path: str = "stock_analysis.db"):
        """Store the database path and create any missing tables.

        Args:
            db_path: Path of the SQLite database file; SQLite creates the
                file if it does not exist yet.
        """
        self.db_path = db_path
        self.init_database()

    @contextmanager
    def _connection(self):
        """Yield a connection that commits on success, rolls back on error,
        and is always closed afterwards."""
        conn = sqlite3.connect(self.db_path)
        try:
            # ``with conn`` only commits / rolls back the transaction; it does
            # not close the connection, hence the surrounding try/finally.
            with conn:
                yield conn
        finally:
            conn.close()

    def init_database(self):
        """Create the ``companies``, ``stock_prices``, ``financial_data`` and
        ``analysis_results`` tables if they do not exist yet."""
        with self._connection() as conn:
            cursor = conn.cursor()
            # Company master data: one row per ticker symbol.
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS companies (
                    symbol TEXT PRIMARY KEY,
                    name TEXT NOT NULL,
                    sector TEXT,
                    industry TEXT,
                    market_cap REAL,
                    employees INTEGER,
                    website TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')
            # Daily price bars; at most one row per (symbol, date).
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS stock_prices (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    symbol TEXT NOT NULL,
                    date DATE NOT NULL,
                    open REAL,
                    high REAL,
                    low REAL,
                    close REAL,
                    volume INTEGER,
                    adjusted_close REAL,
                    FOREIGN KEY (symbol) REFERENCES companies (symbol),
                    UNIQUE(symbol, date)
                )
            ''')
            # Financial statement figures per reporting period.
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS financial_data (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    symbol TEXT NOT NULL,
                    period TEXT NOT NULL,
                    year INTEGER NOT NULL,
                    quarter INTEGER,
                    revenue REAL,
                    net_income REAL,
                    total_assets REAL,
                    total_liabilities REAL,
                    shareholders_equity REAL,
                    cash REAL,
                    debt REAL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (symbol) REFERENCES companies (symbol),
                    UNIQUE(symbol, period, year, quarter)
                )
            ''')
            # Analysis results; every call to save_analysis_result appends a row.
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS analysis_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    symbol TEXT NOT NULL,
                    analysis_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    valuation_score REAL,
                    financial_health_score REAL,
                    growth_score REAL,
                    risk_score REAL,
                    overall_score REAL,
                    recommendation TEXT,
                    analysis_data TEXT, -- detailed analysis stored as JSON
                    FOREIGN KEY (symbol) REFERENCES companies (symbol)
                )
            ''')

    def save_company_info(self, symbol: str, company_data: Dict):
        """Insert or update the ``companies`` row for ``symbol``.

        Missing keys default to ``''`` for text fields and ``0`` for numeric
        fields. On update, ``created_at`` keeps its original value and
        ``updated_at`` is set to the current local time.

        Args:
            symbol: Ticker symbol (primary key).
            company_data: Mapping with optional keys ``name``, ``sector``,
                ``industry``, ``market_cap``, ``employees`` and ``website``.
        """
        row = (
            symbol,
            company_data.get('name', ''),
            company_data.get('sector', ''),
            company_data.get('industry', ''),
            company_data.get('market_cap', 0),
            company_data.get('employees', 0),
            company_data.get('website', ''),
            # Same text sqlite3's default datetime adapter produced (that
            # adapter is deprecated since Python 3.12), so the format is unchanged.
            datetime.now().isoformat(sep=' '),
        )
        with self._connection() as conn:
            # Upsert instead of INSERT OR REPLACE: REPLACE deletes the old row
            # and inserts a new one, which reset created_at on every update.
            conn.execute('''
                INSERT INTO companies
                (symbol, name, sector, industry, market_cap, employees, website, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ON CONFLICT(symbol) DO UPDATE SET
                    name = excluded.name,
                    sector = excluded.sector,
                    industry = excluded.industry,
                    market_cap = excluded.market_cap,
                    employees = excluded.employees,
                    website = excluded.website,
                    updated_at = excluded.updated_at
            ''', row)

    def save_stock_prices(self, symbol: str, price_data: pd.DataFrame):
        """Append the rows of ``price_data`` to ``stock_prices`` for ``symbol``.

        Column names of ``price_data`` are used as-is as ``stock_prices``
        column names, and its index is not written. ``price_data`` itself is
        left unmodified.

        NOTE(review): rows whose (symbol, date) already exist violate the
        table's UNIQUE constraint, so re-saving an overlapping date range
        fails. Consider upsert semantics if that is expected.
        """
        # assign() returns a copy, so the caller's DataFrame does not gain a
        # ``symbol`` column as a side effect.
        rows = price_data.assign(symbol=symbol)
        with self._connection() as conn:
            rows.to_sql('stock_prices', conn, if_exists='append', index=False)

    def save_financial_data(self, symbol: str, financial_data: Dict):
        """Insert or replace one ``financial_data`` row per period.

        Args:
            symbol: Ticker symbol.
            financial_data: Mapping of period label -> dict of figures
                (``year``, ``quarter``, ``revenue``, ``net_income``,
                ``total_assets``, ``total_liabilities``,
                ``shareholders_equity``, ``cash``, ``debt``). Missing figures
                are stored as 0.
        """
        rows = [
            (
                symbol,
                period,
                data.get('year', 0),
                data.get('quarter', 0),
                data.get('revenue', 0),
                data.get('net_income', 0),
                data.get('total_assets', 0),
                data.get('total_liabilities', 0),
                data.get('shareholders_equity', 0),
                data.get('cash', 0),
                data.get('debt', 0),
            )
            for period, data in financial_data.items()
        ]
        with self._connection() as conn:
            conn.executemany('''
                INSERT OR REPLACE INTO financial_data
                (symbol, period, year, quarter, revenue, net_income, total_assets,
                total_liabilities, shareholders_equity, cash, debt)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ''', rows)

    def save_analysis_result(self, symbol: str, analysis_data: Dict):
        """Append one analysis result row for ``symbol``.

        Scores default to 0 and ``recommendation`` defaults to ``''``.
        ``analysis_data['detailed_analysis']`` (default ``{}``) is stored as
        a JSON string in the ``analysis_data`` column.
        """
        with self._connection() as conn:
            conn.execute('''
                INSERT INTO analysis_results
                (symbol, valuation_score, financial_health_score, growth_score,
                risk_score, overall_score, recommendation, analysis_data)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ''', (
                symbol,
                analysis_data.get('valuation_score', 0),
                analysis_data.get('financial_health_score', 0),
                analysis_data.get('growth_score', 0),
                analysis_data.get('risk_score', 0),
                analysis_data.get('overall_score', 0),
                analysis_data.get('recommendation', ''),
                json.dumps(analysis_data.get('detailed_analysis', {})),
            ))

    def get_company_info(self, symbol: str) -> Optional[Dict]:
        """Return the stored company fields for ``symbol``, or None if absent.

        The result has keys ``symbol``, ``name``, ``sector``, ``industry``,
        ``market_cap``, ``employees`` and ``website``.
        """
        columns = ('symbol', 'name', 'sector', 'industry',
                   'market_cap', 'employees', 'website')
        with self._connection() as conn:
            # Explicit column list instead of SELECT * so the mapping below
            # does not depend on the table's physical column order.
            row = conn.execute(
                f'SELECT {", ".join(columns)} FROM companies WHERE symbol = ?',
                (symbol,),
            ).fetchone()
        if row is None:
            return None
        return dict(zip(columns, row))

    def get_latest_analysis(self, symbol: str) -> Optional[Dict]:
        """Return the most recent analysis result for ``symbol``, or None.

        ``analysis_data`` is decoded from JSON; an empty or NULL column
        yields ``{}``.
        """
        columns = ('symbol', 'analysis_date', 'valuation_score',
                   'financial_health_score', 'growth_score', 'risk_score',
                   'overall_score', 'recommendation', 'analysis_data')
        with self._connection() as conn:
            # CURRENT_TIMESTAMP has one-second resolution, so several results
            # can share an analysis_date; id DESC picks the last one inserted.
            row = conn.execute(
                f'SELECT {", ".join(columns)} FROM analysis_results '
                'WHERE symbol = ? '
                'ORDER BY analysis_date DESC, id DESC '
                'LIMIT 1',
                (symbol,),
            ).fetchone()
        if row is None:
            return None
        result = dict(zip(columns, row))
        raw = result['analysis_data']
        result['analysis_data'] = json.loads(raw) if raw else {}
        return result