"""Database management module.

Persists company profiles, daily stock prices, financial statements and
analysis results in a local SQLite database.
"""
import json
import sqlite3
from contextlib import contextmanager
from datetime import datetime
from typing import Dict, Iterator, List, Optional

import pandas as pd


class StockDatabase:
    """SQLite-backed store for stock analysis data.

    Every public method opens its own short-lived connection to ``db_path``;
    no connection is held open between calls.
    """

    def __init__(self, db_path: str = "stock_analysis.db"):
        """Remember the database path and make sure all tables exist.

        Args:
            db_path: Path of the SQLite database file (created if missing).
        """
        self.db_path = db_path
        self.init_database()

    @contextmanager
    def _connect(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection that commits on success, rolls back on error, and is always closed."""
        conn = sqlite3.connect(self.db_path)
        try:
            # ``with conn`` only manages the transaction; it never closes the
            # connection, hence the explicit close in ``finally``.
            with conn:
                yield conn
        finally:
            conn.close()

    def init_database(self):
        """Create the four tables used by this class if they do not exist yet."""
        with self._connect() as conn:
            cursor = conn.cursor()

            # Company profile table
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS companies (
                    symbol TEXT PRIMARY KEY,
                    name TEXT NOT NULL,
                    sector TEXT,
                    industry TEXT,
                    market_cap REAL,
                    employees INTEGER,
                    website TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Stock price table; one row per (symbol, date)
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS stock_prices (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    symbol TEXT NOT NULL,
                    date DATE NOT NULL,
                    open REAL,
                    high REAL,
                    low REAL,
                    close REAL,
                    volume INTEGER,
                    adjusted_close REAL,
                    FOREIGN KEY (symbol) REFERENCES companies (symbol),
                    UNIQUE(symbol, date)
                )
            ''')

            # Financial statement table; one row per (symbol, period, year, quarter)
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS financial_data (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    symbol TEXT NOT NULL,
                    period TEXT NOT NULL,
                    year INTEGER NOT NULL,
                    quarter INTEGER,
                    revenue REAL,
                    net_income REAL,
                    total_assets REAL,
                    total_liabilities REAL,
                    shareholders_equity REAL,
                    cash REAL,
                    debt REAL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    FOREIGN KEY (symbol) REFERENCES companies (symbol),
                    UNIQUE(symbol, period, year, quarter)
                )
            ''')

            # Analysis result table. ``analysis_data`` holds the detailed
            # analysis as a JSON string (kept out of the SQL text so an SQL
            # ``--`` comment can never swallow the rest of the statement).
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS analysis_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    symbol TEXT NOT NULL,
                    analysis_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    valuation_score REAL,
                    financial_health_score REAL,
                    growth_score REAL,
                    risk_score REAL,
                    overall_score REAL,
                    recommendation TEXT,
                    analysis_data TEXT,
                    FOREIGN KEY (symbol) REFERENCES companies (symbol)
                )
            ''')

    def save_company_info(self, symbol: str, company_data: Dict):
        """Insert or update a company's profile.

        Missing keys in ``company_data`` default to ``''`` (text fields) or
        ``0`` (numeric fields). On update the original ``created_at`` is kept
        and ``updated_at`` is set to the current local time.

        Args:
            symbol: Ticker symbol (primary key of ``companies``).
            company_data: Mapping with optional keys ``name``, ``sector``,
                ``industry``, ``market_cap``, ``employees``, ``website``.
        """
        values = (
            company_data.get('name', ''),
            company_data.get('sector', ''),
            company_data.get('industry', ''),
            company_data.get('market_cap', 0),
            company_data.get('employees', 0),
            company_data.get('website', ''),
            # Same text the (deprecated since 3.12) default datetime adapter stored.
            datetime.now().isoformat(" "),
        )
        with self._connect() as conn:
            cursor = conn.cursor()
            # Update first so an existing row keeps its created_at; the old
            # INSERT OR REPLACE deleted the row and reset created_at.
            cursor.execute('''
                UPDATE companies
                SET name = ?, sector = ?, industry = ?, market_cap = ?,
                    employees = ?, website = ?, updated_at = ?
                WHERE symbol = ?
            ''', (*values, symbol))
            if cursor.rowcount == 0:
                cursor.execute('''
                    INSERT INTO companies
                    (symbol, name, sector, industry, market_cap, employees, website, updated_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                ''', (symbol, *values))

    def save_stock_prices(self, symbol: str, price_data: pd.DataFrame):
        """Append price rows for ``symbol`` to ``stock_prices``.

        The caller's DataFrame is not modified; a copy with a ``symbol`` column
        is written. DataFrame columns are matched to table columns by name and
        the index is not written. Because of ``UNIQUE(symbol, date)``, saving a
        (symbol, date) pair that is already stored raises
        ``sqlite3.IntegrityError``.
        """
        with self._connect() as conn:
            price_data.assign(symbol=symbol).to_sql(
                'stock_prices', conn, if_exists='append', index=False
            )

    def save_financial_data(self, symbol: str, financial_data: Dict):
        """Insert or replace financial rows for ``symbol``.

        Args:
            symbol: Ticker symbol.
            financial_data: Mapping of period label -> dict with optional keys
                ``year``, ``quarter``, ``revenue``, ``net_income``,
                ``total_assets``, ``total_liabilities``,
                ``shareholders_equity``, ``cash``, ``debt`` (missing -> 0).
        """
        rows = [
            (
                symbol,
                period,
                data.get('year', 0),
                data.get('quarter', 0),
                data.get('revenue', 0),
                data.get('net_income', 0),
                data.get('total_assets', 0),
                data.get('total_liabilities', 0),
                data.get('shareholders_equity', 0),
                data.get('cash', 0),
                data.get('debt', 0),
            )
            for period, data in financial_data.items()
        ]
        with self._connect() as conn:
            conn.executemany('''
                INSERT OR REPLACE INTO financial_data
                (symbol, period, year, quarter, revenue, net_income, total_assets,
                 total_liabilities, shareholders_equity, cash, debt)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ''', rows)

    def save_analysis_result(self, symbol: str, analysis_data: Dict):
        """Append one analysis result for ``symbol``.

        Scores default to 0 and ``recommendation`` to ``''`` when missing;
        ``analysis_data['detailed_analysis']`` (default ``{}``) is stored as JSON.
        """
        with self._connect() as conn:
            conn.execute('''
                INSERT INTO analysis_results
                (symbol, valuation_score, financial_health_score, growth_score,
                 risk_score, overall_score, recommendation, analysis_data)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ''', (
                symbol,
                analysis_data.get('valuation_score', 0),
                analysis_data.get('financial_health_score', 0),
                analysis_data.get('growth_score', 0),
                analysis_data.get('risk_score', 0),
                analysis_data.get('overall_score', 0),
                analysis_data.get('recommendation', ''),
                json.dumps(analysis_data.get('detailed_analysis', {})),
            ))

    def get_company_info(self, symbol: str) -> Optional[Dict]:
        """Return the stored profile of ``symbol``, or ``None`` if unknown."""
        keys = ('symbol', 'name', 'sector', 'industry',
                'market_cap', 'employees', 'website')
        with self._connect() as conn:
            # Explicit column list: robust against schema column additions.
            row = conn.execute(
                'SELECT symbol, name, sector, industry, market_cap, employees, website '
                'FROM companies WHERE symbol = ?',
                (symbol,),
            ).fetchone()
        if row is None:
            return None
        return dict(zip(keys, row))

    def get_latest_analysis(self, symbol: str) -> Optional[Dict]:
        """Return the most recent analysis result for ``symbol``, or ``None``.

        ``analysis_data`` is decoded from JSON (``{}`` when empty/NULL).
        """
        with self._connect() as conn:
            # analysis_date has one-second resolution; id breaks ties so the
            # row inserted last really is returned.
            row = conn.execute('''
                SELECT symbol, analysis_date, valuation_score,
                       financial_health_score, growth_score, risk_score,
                       overall_score, recommendation, analysis_data
                FROM analysis_results
                WHERE symbol = ?
                ORDER BY analysis_date DESC, id DESC
                LIMIT 1
            ''', (symbol,)).fetchone()
        if row is None:
            return None
        return {
            'symbol': row[0],
            'analysis_date': row[1],
            'valuation_score': row[2],
            'financial_health_score': row[3],
            'growth_score': row[4],
            'risk_score': row[5],
            'overall_score': row[6],
            'recommendation': row[7],
            'analysis_data': json.loads(row[8]) if row[8] else {},
        }