# Files
# 2026-04-25 19:21:03 +08:00
#
# 353 lines
# 12 KiB
# Python

# -*- coding: utf-8 -*-
"""
数据库操作类
"""
import sqlite3
import hashlib
import json
import logging
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Tuple
from pathlib import Path
from config import DATABASE_CONFIG, RSS_SOURCES
class DatabaseManager:
    """Data-access object for the application's SQLite database.

    A single lazily opened ``sqlite3`` connection is shared by every method.
    On construction the schema (read from ``database_schema.sql`` next to this
    module) and the initial RSS sources from ``RSS_SOURCES`` are created when
    the database file is missing or empty.

    The instance can be used as a context manager; the connection is closed
    on exit.
    """

    # Backslash is used as the escape character in ``LIKE ... ESCAPE`` clauses.
    _LIKE_ESCAPE = '\\'
    # Authority level assumed when a source has none (yields a zero bonus).
    _DEFAULT_AUTHORITY_LEVEL = 4

    def __init__(self):
        """Read the database settings from ``DATABASE_CONFIG`` and initialise.

        Raises:
            ValueError: if the configured database type is not ``'sqlite'``.
        """
        self.db_type = DATABASE_CONFIG['type']
        if self.db_type != 'sqlite':
            # Previously this surfaced later as an AttributeError on db_path.
            raise ValueError(f"Unsupported database type: {self.db_type!r}")
        self.db_path = DATABASE_CONFIG['sqlite']['path']
        self.conn = None
        self.logger = logging.getLogger(__name__)
        self._init_database()

    def __enter__(self):
        """Return the manager itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection; exceptions are never suppressed."""
        self.close()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _get_connection(self):
        """Return the shared connection, opening it on first use.

        Raises:
            NotImplementedError: for any database type other than SQLite
                (MySQL/PostgreSQL support may be added later).
        """
        if self.db_type != 'sqlite':
            raise NotImplementedError(
                f"Unsupported database type: {self.db_type!r}"
            )
        if self.conn is None:
            self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
            # Rows behave like mappings so callers can do dict(row).
            self.conn.row_factory = sqlite3.Row
        return self.conn

    @staticmethod
    def _to_db_timestamp(value):
        """Serialise ``datetime`` values explicitly; pass anything else through.

        ``isoformat(" ")`` produces the same text as sqlite3's built-in
        datetime adapter, which is deprecated since Python 3.12.
        """
        if isinstance(value, datetime):
            return value.isoformat(" ")
        return value

    @staticmethod
    def _clean_keywords(keywords) -> List[str]:
        """Drop empty / whitespace-only keywords (``None`` becomes ``[]``)."""
        return [kw for kw in (keywords or []) if kw and kw.strip()]

    @classmethod
    def _like_pattern(cls, term: str) -> str:
        """Build a ``%term%`` LIKE pattern with wildcards in *term* escaped."""
        esc = cls._LIKE_ESCAPE
        # The escape character itself must be escaped first.
        escaped = (
            term.replace(esc, esc + esc)
            .replace('%', esc + '%')
            .replace('_', esc + '_')
        )
        return f"%{escaped}%"

    @classmethod
    def _authority_of(cls, article: Dict) -> int:
        """Return the article's authority level, defaulting when absent/NULL."""
        level = article.get('authority_level')
        return cls._DEFAULT_AUTHORITY_LEVEL if level is None else level

    def _init_database(self):
        """Create the schema and seed data when the database is new.

        A zero-byte file (left behind by a start-up that failed after
        ``sqlite3.connect`` created it) is treated like a missing file.
        """
        db_file = Path(self.db_path)
        if db_file.exists() and db_file.stat().st_size > 0:
            return
        if self.db_path != ':memory:':
            # sqlite3.connect() does not create missing parent directories.
            db_file.parent.mkdir(parents=True, exist_ok=True)
        self._create_tables()
        self._insert_initial_data()

    def _create_tables(self):
        """Create all tables by executing ``database_schema.sql``.

        Raises:
            FileNotFoundError: if the schema file is missing. This is checked
                before the database is opened so no empty file is left behind.
        """
        sql_file = Path(__file__).parent / 'database_schema.sql'
        if not sql_file.exists():
            raise FileNotFoundError(
                f"Database schema file not found: {sql_file}"
            )
        sql_script = sql_file.read_text(encoding='utf-8')

        conn = self._get_connection()
        conn.executescript(sql_script)
        conn.commit()
        self.logger.info("数据库表创建完成")

    def _insert_initial_data(self):
        """Insert the RSS sources configured in ``RSS_SOURCES``.

        Sources are attached to industries by matching the ``RSS_SOURCES`` key
        against ``industries.name_en``; existing rows are kept
        (``INSERT OR IGNORE``).
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        # Map industry English name -> primary key.
        cursor.execute("SELECT id, name_en FROM industries")
        industry_map = {row['name_en']: row['id'] for row in cursor.fetchall()}

        # Build every row first so a malformed source entry fails before
        # anything has been written.
        rows = []
        for industry, sources in RSS_SOURCES.items():
            if industry not in industry_map:
                self.logger.warning(
                    "Industry %r not found in database; skipping its RSS sources",
                    industry,
                )
                continue
            industry_id = industry_map[industry]
            for source in sources:
                rows.append((
                    industry_id,
                    source['name'],
                    source['url'],
                    source['authority_level'],
                    source['language'],
                ))

        cursor.executemany("""
            INSERT OR IGNORE INTO rss_sources
            (industry_id, source_name, source_url, source_type, authority_level, language)
            VALUES (?, ?, ?, 'rss', ?, ?)
        """, rows)
        conn.commit()
        self.logger.info("初始RSS源数据插入完成")

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def get_industries(self) -> List[Dict]:
        """Return every industry as a dict, ordered by English name."""
        cursor = self._get_connection().cursor()
        cursor.execute("SELECT * FROM industries ORDER BY name_en")
        return [dict(row) for row in cursor.fetchall()]

    def get_rss_sources(self, industry_id: Optional[int] = None,
                        active_only: bool = True) -> List[Dict]:
        """Return RSS sources ordered by authority level, then name.

        Args:
            industry_id: restrict to one industry; a falsy value (``None`` or
                ``0``) means no industry filter.
            active_only: when true, only rows with ``is_active = 1``.
        """
        cursor = self._get_connection().cursor()

        query = "SELECT * FROM rss_sources WHERE 1=1"
        params = []
        if industry_id:
            query += " AND industry_id = ?"
            params.append(industry_id)
        if active_only:
            query += " AND is_active = 1"
        query += " ORDER BY authority_level, source_name"

        cursor.execute(query, params)
        return [dict(row) for row in cursor.fetchall()]

    def save_article(self, article_data: Dict) -> Optional[int]:
        """Store an article unless an identical one is already present.

        Duplicates are detected through a SHA-256 hash of title + URL.

        Args:
            article_data: requires ``title``, ``original_url`` and
                ``source_id``; ``content``, ``summary``, ``author``,
                ``published_date``, ``language`` and ``keywords`` are optional.

        Returns:
            The new row id, or ``None`` when the article already exists or the
            insert failed (the failure is logged and rolled back).

        Raises:
            KeyError: if ``title`` or ``original_url`` is missing.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        # Hash used for de-duplication.
        content_hash = hashlib.sha256(
            f"{article_data['title']}{article_data['original_url']}".encode()
        ).hexdigest()

        cursor.execute(
            "SELECT id FROM articles WHERE article_hash = ?", (content_hash,)
        )
        if cursor.fetchone():
            return None  # already stored

        try:
            cursor.execute("""
                INSERT INTO articles
                (title, content, summary, author, source_id, original_url,
                 published_date, language, keywords, article_hash)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                article_data['title'],
                article_data.get('content', ''),
                article_data.get('summary', ''),
                article_data.get('author', ''),
                article_data['source_id'],
                article_data['original_url'],
                self._to_db_timestamp(article_data.get('published_date')),
                article_data.get('language', 'en'),
                json.dumps(article_data.get('keywords', [])),
                content_hash,
            ))
            article_id = cursor.lastrowid
            conn.commit()
        except Exception as e:
            # Deliberate best-effort boundary: one bad article must not abort
            # the caller's batch, so the error is logged and None is returned.
            self.logger.error(f"保存文章失败: {e}")
            conn.rollback()
            return None

        self.logger.debug(f"保存文章: {article_data['title']}")
        return article_id

    def create_search_log(self, keywords: str, industry_id: Optional[int] = None,
                          language: str = 'en', user_ip: str = '') -> int:
        """Record a search request and return the id of the new log row."""
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute("""
            INSERT INTO search_logs (keywords, industry_id, language, user_ip)
            VALUES (?, ?, ?, ?)
        """, (keywords, industry_id, language, user_ip))
        search_log_id = cursor.lastrowid
        conn.commit()
        return search_log_id

    def save_search_results(self, search_log_id: int, articles: List[Dict]):
        """Store the ranked results of a search and update its result count.

        Ranks start at 1 and follow the order of *articles*. The write is
        all-or-nothing: rows are built before touching the database and any
        database error rolls the transaction back before being re-raised.

        Raises:
            KeyError: if an article dict has no ``id`` (nothing is written).
            sqlite3.Error: if the database write fails.
        """
        rows = [
            (search_log_id, article['id'], article.get('relevance_score', 0), rank)
            for rank, article in enumerate(articles, 1)
        ]

        conn = self._get_connection()
        cursor = conn.cursor()
        try:
            cursor.executemany("""
                INSERT INTO search_results
                (search_log_id, article_id, relevance_score, rank_position)
                VALUES (?, ?, ?, ?)
            """, rows)
            # Keep the log row's result count in sync.
            cursor.execute("""
                UPDATE search_logs SET results_count = ? WHERE id = ?
            """, (len(rows), search_log_id))
            conn.commit()
        except sqlite3.Error:
            # Never leave half a result set pending on the shared connection.
            conn.rollback()
            raise

    def search_articles(self, keywords: List[str], industry_id: Optional[int] = None,
                        language: Optional[str] = None, limit: int = 50,
                        days_back: int = 30) -> List[Dict]:
        """Search articles by keyword with optional filters.

        Args:
            keywords: terms matched literally (``%`` and ``_`` are not
                wildcards) against title or content; an article matches when
                any term matches. Blank terms are ignored; with no usable
                term, no keyword filter is applied.
            industry_id: restrict to one industry (falsy means no filter).
            language: restrict to one article language.
            limit: maximum number of rows fetched from the database.
            days_back: only articles published within this many days;
                ``0`` or a negative value disables the date filter.

        Returns:
            Article dicts extended with ``source_name``, ``authority_level``,
            ``industry_name`` and ``relevance_score``, sorted by authority
            level (ascending) and then by relevance (descending).
        """
        cursor = self._get_connection().cursor()
        terms = self._clean_keywords(keywords)

        query = """
            SELECT a.*, rs.source_name, rs.authority_level, i.name_cn as industry_name
            FROM articles a
            JOIN rss_sources rs ON a.source_id = rs.id
            JOIN industries i ON rs.industry_id = i.id
            WHERE 1=1
        """
        params = []

        # Date range filter.
        if days_back > 0:
            date_threshold = datetime.now() - timedelta(days=days_back)
            query += " AND a.published_date >= ?"
            params.append(self._to_db_timestamp(date_threshold))

        # Industry filter.
        if industry_id:
            query += " AND rs.industry_id = ?"
            params.append(industry_id)

        # Language filter.
        if language:
            query += " AND a.language = ?"
            params.append(language)

        # Keyword filter: any term in title or content.
        if terms:
            like_clause = (
                "(a.title LIKE ? ESCAPE '\\' OR a.content LIKE ? ESCAPE '\\')"
            )
            for term in terms:
                pattern = self._like_pattern(term)
                params.extend([pattern, pattern])
            query += f" AND ({' OR '.join([like_clause] * len(terms))})"

        # Ordering and row limit.
        query += " ORDER BY rs.authority_level ASC, a.published_date DESC LIMIT ?"
        params.append(limit)

        cursor.execute(query, params)
        results = [dict(row) for row in cursor.fetchall()]

        for result in results:
            result['relevance_score'] = self._calculate_relevance(result, terms)

        # Stable sort: ties keep the newest-first order produced by SQL.
        results.sort(
            key=lambda r: (self._authority_of(r), -r['relevance_score'])
        )
        return results

    def _calculate_relevance(self, article: Dict, keywords: List[str]) -> float:
        """Score how well *article* matches *keywords* (capped at 10.0).

        Each case-insensitive occurrence counts 2.0 in the title and 0.5 in
        the content; sources get a bonus of ``(4 - authority_level) * 0.1``.
        Returns 1.0 when there is no usable keyword. Missing or NULL title,
        content and authority level are tolerated.
        """
        terms = [kw.lower() for kw in self._clean_keywords(keywords)]
        if not terms:
            return 1.0

        # Columns may be NULL in the database, hence the "or ''".
        title = (article.get('title') or '').lower()
        content = (article.get('content') or '').lower()

        score = 0.0
        for term in terms:
            # Title hits weigh more than content hits.
            score += title.count(term) * 2.0 + content.count(term) * 0.5

        # Adjust by source authority (level 1 is the most authoritative).
        score += (4 - self._authority_of(article)) * 0.1
        return min(score, 10.0)

    def get_search_history(self, limit: int = 20) -> List[Dict]:
        """Return the most recent search logs, newest first."""
        cursor = self._get_connection().cursor()
        cursor.execute("""
            SELECT sl.*, i.name_cn as industry_name
            FROM search_logs sl
            LEFT JOIN industries i ON sl.industry_id = i.id
            ORDER BY sl.search_time DESC
            LIMIT ?
        """, (limit,))
        return [dict(row) for row in cursor.fetchall()]

    def save_exported_doc(self, search_log_id: int, filename: str,
                          file_path: str, articles_count: int) -> int:
        """Record an exported document and return the id of the new row."""
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute("""
            INSERT INTO exported_docs
            (search_log_id, filename, file_path, articles_count)
            VALUES (?, ?, ?, ?)
        """, (search_log_id, filename, file_path, articles_count))
        doc_id = cursor.lastrowid
        conn.commit()
        return doc_id

    def update_rss_source_check_time(self, source_id: int):
        """Set ``last_checked`` of the given RSS source to the current time."""
        conn = self._get_connection()
        conn.execute("""
            UPDATE rss_sources SET last_checked = CURRENT_TIMESTAMP WHERE id = ?
        """, (source_id,))
        conn.commit()

    def get_statistics(self) -> Dict:
        """Return system-wide counters.

        Keys: ``total_articles``, ``today_articles``, ``total_searches``,
        ``active_sources`` and ``articles_by_industry`` (a list of
        ``{'name_cn', 'count'}`` dicts, largest count first).
        """
        cursor = self._get_connection().cursor()

        def count(sql: str) -> int:
            """Run a single-row query exposing a ``count`` column."""
            cursor.execute(sql)
            return cursor.fetchone()['count']

        stats = {
            # Total number of articles.
            'total_articles': count("SELECT COUNT(*) as count FROM articles"),
            # Articles scraped today (SQLite's DATE('now') is UTC).
            'today_articles': count("""
                SELECT COUNT(*) as count FROM articles
                WHERE DATE(scraped_date) = DATE('now')
            """),
            # Total number of searches.
            'total_searches': count("SELECT COUNT(*) as count FROM search_logs"),
            # Number of active RSS sources.
            'active_sources': count(
                "SELECT COUNT(*) as count FROM rss_sources WHERE is_active = 1"
            ),
        }

        # Article count per industry (industries without articles included).
        cursor.execute("""
            SELECT i.name_cn, COUNT(a.id) as count
            FROM industries i
            LEFT JOIN rss_sources rs ON i.id = rs.industry_id
            LEFT JOIN articles a ON rs.id = a.source_id
            GROUP BY i.id, i.name_cn
            ORDER BY count DESC
        """)
        stats['articles_by_industry'] = [dict(row) for row in cursor.fetchall()]
        return stats

    def close(self):
        """Close the database connection if one is open."""
        if self.conn:
            self.conn.close()
            self.conn = None