353 lines
12 KiB
Python
353 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
数据库操作类
|
|
"""
|
|
|
|
import sqlite3
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import List, Dict, Optional, Tuple
|
|
from pathlib import Path
|
|
|
|
from config import DATABASE_CONFIG, RSS_SOURCES
|
|
|
|
class DatabaseManager:
    """Data-access layer for industries, RSS sources, articles, search logs and exports.

    Only the SQLite backend is implemented. A single connection is opened lazily
    and reused by every method. It is opened with ``check_same_thread=False``,
    but no locking is done here, so concurrent callers must serialize access
    themselves. The manager can also be used in a ``with`` statement; leaving
    the block closes the connection.
    """

    def __init__(self):
        """Read backend settings from ``DATABASE_CONFIG`` and prepare the schema.

        Raises:
            NotImplementedError: if ``DATABASE_CONFIG['type']`` is not ``'sqlite'``.
        """
        self.db_type = DATABASE_CONFIG['type']
        # Always defined so attribute access never fails; only SQLite uses it today.
        self.db_path = None
        if self.db_type == 'sqlite':
            self.db_path = DATABASE_CONFIG['sqlite']['path']
        self.conn = None
        self.logger = logging.getLogger(__name__)
        self._init_database()

    def __enter__(self):
        """Return this manager for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Close the connection when the ``with`` block ends; never suppress exceptions."""
        self.close()
        return False

    def _get_connection(self):
        """Return the shared connection, opening it on first use.

        Returns:
            A ``sqlite3.Connection`` whose rows are ``sqlite3.Row`` objects
            (accessible by column name and convertible with ``dict(row)``).

        Raises:
            NotImplementedError: for any backend other than SQLite.
        """
        if self.db_type != 'sqlite':
            # MySQL/PostgreSQL support can be added here later.
            raise NotImplementedError(f"不支持的数据库类型: {self.db_type}")
        if self.conn is None:
            self._ensure_parent_dir()
            self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
            self.conn.row_factory = sqlite3.Row
        return self.conn

    def _ensure_parent_dir(self):
        """Create the directory that will contain the SQLite file, if missing."""
        path = str(self.db_path)
        # ':memory:' and 'file:' URIs are not plain filesystem paths.
        if path == ':memory:' or path.startswith('file:'):
            return
        Path(path).parent.mkdir(parents=True, exist_ok=True)

    def _table_exists(self, table_name: str) -> bool:
        """Return True if a table called ``table_name`` exists in the database."""
        cursor = self._get_connection().cursor()
        cursor.execute(
            "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?",
            (table_name,),
        )
        return cursor.fetchone() is not None

    def _init_database(self):
        """Create the schema and seed RSS sources if the database is not set up yet.

        Initialization is keyed on the presence of the ``industries`` table
        rather than the database file. ``sqlite3.connect`` creates the file
        immediately, so a run that failed half-way would otherwise leave an
        empty file that is never initialized.
        """
        if self._table_exists('industries'):
            return
        self._create_tables()
        if self._table_exists('industries'):
            self._insert_initial_data()
        else:
            self.logger.error("数据库表不存在，跳过初始RSS源数据插入")

    def _create_tables(self):
        """Execute ``database_schema.sql`` (located next to this module).

        If the schema file does not exist, nothing is created and a warning is
        logged.
        """
        sql_file = Path(__file__).parent / 'database_schema.sql'
        if not sql_file.exists():
            self.logger.warning("未找到数据库结构文件: %s", sql_file)
            return

        conn = self._get_connection()
        # executescript() runs every statement in the file in one call.
        conn.executescript(sql_file.read_text(encoding='utf-8'))
        conn.commit()
        self.logger.info("数据库表创建完成")

    def _insert_initial_data(self):
        """Seed ``rss_sources`` from the ``RSS_SOURCES`` configuration.

        ``RSS_SOURCES`` is iterated as a mapping from an industry's ``name_en``
        to a list of source dicts. Each dict must provide ``name``, ``url``,
        ``authority_level`` and ``language``. Industries not present in the
        ``industries`` table are skipped. ``INSERT OR IGNORE`` means rows that
        violate a constraint are skipped rather than raising.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        # Map industry English name -> id.
        cursor.execute("SELECT id, name_en FROM industries")
        industry_map = {row['name_en']: row['id'] for row in cursor.fetchall()}

        rows = []
        for industry, sources in RSS_SOURCES.items():
            if industry not in industry_map:
                continue
            industry_id = industry_map[industry]
            for source in sources:
                rows.append((industry_id, source['name'], source['url'],
                             source['authority_level'], source['language']))

        cursor.executemany("""
            INSERT OR IGNORE INTO rss_sources
            (industry_id, source_name, source_url, source_type, authority_level, language)
            VALUES (?, ?, ?, 'rss', ?, ?)
        """, rows)

        conn.commit()
        self.logger.info("初始RSS源数据插入完成")

    def get_industries(self) -> List[Dict]:
        """Return every row of ``industries`` as a dict, ordered by ``name_en``."""
        cursor = self._get_connection().cursor()
        cursor.execute("SELECT * FROM industries ORDER BY name_en")
        return [dict(row) for row in cursor.fetchall()]

    def get_rss_sources(self, industry_id: Optional[int] = None,
                        active_only: bool = True) -> List[Dict]:
        """Return RSS sources, optionally filtered.

        Args:
            industry_id: restrict to this industry (a falsy value means all).
            active_only: only return rows with ``is_active = 1``.

        Returns:
            Source rows as dicts, ordered by ``authority_level`` then ``source_name``.
        """
        cursor = self._get_connection().cursor()

        query = "SELECT * FROM rss_sources WHERE 1=1"
        params = []

        if industry_id:
            query += " AND industry_id = ?"
            params.append(industry_id)

        if active_only:
            query += " AND is_active = 1"

        query += " ORDER BY authority_level, source_name"

        cursor.execute(query, params)
        return [dict(row) for row in cursor.fetchall()]

    @staticmethod
    def _to_db_datetime(value):
        """Render ``datetime`` values as text exactly like sqlite3's old default adapter.

        Python 3.12 deprecated sqlite3's implicit datetime adapter, which
        produced ``value.isoformat(" ")``. Formatting explicitly keeps stored
        values and comparisons unchanged. Any other value is returned as-is.
        """
        if isinstance(value, datetime):
            return value.isoformat(' ')
        return value

    @staticmethod
    def _like_pattern(keyword: str) -> str:
        """Build a ``%keyword%`` LIKE pattern with ``\\``, ``%`` and ``_`` escaped.

        The pattern must be used with ``ESCAPE '\\'``, so the user's text is
        matched literally instead of being treated as wildcards.
        """
        escaped = (keyword.replace('\\', '\\\\')
                          .replace('%', '\\%')
                          .replace('_', '\\_'))
        return f"%{escaped}%"

    def save_article(self, article_data: Dict) -> Optional[int]:
        """Insert an article unless an identical one is already stored.

        Duplicates are detected by a SHA-256 hash of title + original URL.

        Args:
            article_data: must contain ``title``, ``original_url`` and
                ``source_id``. It may contain ``content``, ``summary`` and
                ``author`` (each defaults to ``''``), plus ``published_date``,
                ``language`` (default ``'en'``) and ``keywords`` (a list,
                stored as JSON).

        Returns:
            The new article id. Returns None if the article already exists, or
            if the insert failed; in that case the error is logged and the
            transaction rolled back.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        # Hash used to prevent storing the same article twice.
        content_hash = hashlib.sha256(
            f"{article_data['title']}{article_data['original_url']}".encode()
        ).hexdigest()

        cursor.execute("SELECT id FROM articles WHERE article_hash = ?", (content_hash,))
        if cursor.fetchone():
            return None  # article already exists

        try:
            cursor.execute("""
                INSERT INTO articles
                (title, content, summary, author, source_id, original_url,
                 published_date, language, keywords, article_hash)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                article_data['title'],
                article_data.get('content', ''),
                article_data.get('summary', ''),
                article_data.get('author', ''),
                article_data['source_id'],
                article_data['original_url'],
                self._to_db_datetime(article_data.get('published_date')),
                article_data.get('language', 'en'),
                json.dumps(article_data.get('keywords', [])),
                content_hash,
            ))
            article_id = cursor.lastrowid
            conn.commit()
        except Exception as e:
            # Best-effort persistence: callers treat None as "not saved".
            self.logger.error(f"保存文章失败: {e}")
            conn.rollback()
            return None

        self.logger.debug(f"保存文章: {article_data['title']}")
        return article_id

    def create_search_log(self, keywords: str, industry_id: Optional[int] = None,
                          language: str = 'en', user_ip: str = '') -> int:
        """Insert a ``search_logs`` row and return its id."""
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute("""
            INSERT INTO search_logs (keywords, industry_id, language, user_ip)
            VALUES (?, ?, ?, ?)
        """, (keywords, industry_id, language, user_ip))

        search_log_id = cursor.lastrowid
        conn.commit()
        return search_log_id

    def save_search_results(self, search_log_id: int, articles: List[Dict]):
        """Record ranked results for a search and update its ``results_count``.

        Args:
            search_log_id: id returned by ``create_search_log``.
            articles: result dicts in rank order. Each needs ``id``;
                ``relevance_score`` defaults to 0.

        All rows are written in one transaction. On failure the transaction
        is rolled back and the exception re-raised, so no partial result set
        is left behind for a later commit to persist.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        # Build every row first so a malformed article fails before any write.
        rows = [
            (search_log_id, article['id'], article.get('relevance_score', 0), rank)
            for rank, article in enumerate(articles, 1)
        ]

        try:
            cursor.executemany("""
                INSERT INTO search_results
                (search_log_id, article_id, relevance_score, rank_position)
                VALUES (?, ?, ?, ?)
            """, rows)

            cursor.execute("""
                UPDATE search_logs SET results_count = ? WHERE id = ?
            """, (len(articles), search_log_id))

            conn.commit()
        except Exception:
            conn.rollback()
            raise

    def search_articles(self, keywords: List[str], industry_id: Optional[int] = None,
                        language: Optional[str] = None, limit: int = 50,
                        days_back: int = 30) -> List[Dict]:
        """Search stored articles by keyword, industry, language and age.

        Args:
            keywords: an article matches if any keyword occurs in its title or
                content. Matching uses SQLite ``LIKE``, with wildcards in the
                keyword escaped. An empty list disables keyword filtering.
            industry_id: restrict to sources of this industry (falsy = all).
            language: restrict to this article language (falsy = all).
            limit: maximum number of rows fetched from the database.
            days_back: only include articles with ``published_date`` in the
                last N days; ``<= 0`` disables the filter.

        Returns:
            Article dicts extended with ``source_name``, ``authority_level``,
            ``industry_name`` and ``relevance_score``. They are sorted by
            authority level (ascending), then relevance (descending). ``limit``
            is applied in SQL (ordered by authority, then newest first) before
            this re-sort.
        """
        cursor = self._get_connection().cursor()

        query = """
            SELECT a.*, rs.source_name, rs.authority_level, i.name_cn as industry_name
            FROM articles a
            JOIN rss_sources rs ON a.source_id = rs.id
            JOIN industries i ON rs.industry_id = i.id
            WHERE 1=1
        """
        params = []

        if days_back > 0:
            date_threshold = datetime.now() - timedelta(days=days_back)
            query += " AND a.published_date >= ?"
            params.append(self._to_db_datetime(date_threshold))

        if industry_id:
            query += " AND rs.industry_id = ?"
            params.append(industry_id)

        if language:
            query += " AND a.language = ?"
            params.append(language)

        if keywords:
            keyword_conditions = []
            for keyword in keywords:
                pattern = self._like_pattern(keyword)
                keyword_conditions.append(
                    "(a.title LIKE ? ESCAPE '\\' OR a.content LIKE ? ESCAPE '\\')"
                )
                params.extend([pattern, pattern])
            query += f" AND ({' OR '.join(keyword_conditions)})"

        query += " ORDER BY rs.authority_level ASC, a.published_date DESC LIMIT ?"
        params.append(limit)

        cursor.execute(query, params)
        results = [dict(row) for row in cursor.fetchall()]

        for result in results:
            result['relevance_score'] = self._calculate_relevance(result, keywords)

        # Authority level first (lower = more authoritative), then relevance.
        results.sort(key=lambda x: (x['authority_level'], -x['relevance_score']))
        return results

    def _calculate_relevance(self, article: Dict, keywords: List[str]) -> float:
        """Score how well ``article`` matches ``keywords``, capped at 10.0.

        Each case-insensitive occurrence of a keyword in the title adds 2.0,
        and each occurrence in the content adds 0.5. Empty keywords are
        ignored. A bonus of ``(4 - authority_level) * 0.1`` is added; a missing
        or NULL authority level counts as 4 (no bonus), and NULL title or
        content count as empty text. With no keywords the score is 1.0.
        """
        if not keywords:
            return 1.0

        # `or ''` handles keys that are present but NULL in the database.
        title = (article.get('title') or '').lower()
        content = (article.get('content') or '').lower()

        score = 0.0
        for keyword in keywords:
            keyword = keyword.lower()
            if not keyword:
                # str.count('') returns len + 1 and would inflate the score.
                continue
            # Title matches weigh more than content matches.
            score += title.count(keyword) * 2.0 + content.count(keyword) * 0.5

        authority_level = article.get('authority_level')
        if authority_level is None:
            authority_level = 4
        score += (4 - authority_level) * 0.1

        return min(score, 10.0)  # cap the maximum score

    def get_search_history(self, limit: int = 20) -> List[Dict]:
        """Return the most recent ``limit`` searches, newest first, with industry name."""
        cursor = self._get_connection().cursor()

        cursor.execute("""
            SELECT sl.*, i.name_cn as industry_name
            FROM search_logs sl
            LEFT JOIN industries i ON sl.industry_id = i.id
            ORDER BY sl.search_time DESC
            LIMIT ?
        """, (limit,))

        return [dict(row) for row in cursor.fetchall()]

    def save_exported_doc(self, search_log_id: int, filename: str,
                          file_path: str, articles_count: int) -> int:
        """Record an exported document in ``exported_docs`` and return its id."""
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute("""
            INSERT INTO exported_docs
            (search_log_id, filename, file_path, articles_count)
            VALUES (?, ?, ?, ?)
        """, (search_log_id, filename, file_path, articles_count))

        doc_id = cursor.lastrowid
        conn.commit()
        return doc_id

    def update_rss_source_check_time(self, source_id: int):
        """Set ``last_checked`` of an RSS source to SQLite's ``CURRENT_TIMESTAMP``."""
        conn = self._get_connection()
        conn.execute("""
            UPDATE rss_sources SET last_checked = CURRENT_TIMESTAMP WHERE id = ?
        """, (source_id,))
        conn.commit()

    def get_statistics(self) -> Dict:
        """Return summary counts for the dashboard.

        Keys: ``total_articles``, ``today_articles`` (``scraped_date`` on
        SQLite's current date), ``total_searches``, ``active_sources`` and
        ``articles_by_industry`` (a list of ``{name_cn, count}`` dicts,
        largest first).
        """
        cursor = self._get_connection().cursor()
        stats = {}

        cursor.execute("SELECT COUNT(*) as count FROM articles")
        stats['total_articles'] = cursor.fetchone()['count']

        cursor.execute("""
            SELECT COUNT(*) as count FROM articles
            WHERE DATE(scraped_date) = DATE('now')
        """)
        stats['today_articles'] = cursor.fetchone()['count']

        cursor.execute("SELECT COUNT(*) as count FROM search_logs")
        stats['total_searches'] = cursor.fetchone()['count']

        cursor.execute("SELECT COUNT(*) as count FROM rss_sources WHERE is_active = 1")
        stats['active_sources'] = cursor.fetchone()['count']

        cursor.execute("""
            SELECT i.name_cn, COUNT(a.id) as count
            FROM industries i
            LEFT JOIN rss_sources rs ON i.id = rs.industry_id
            LEFT JOIN articles a ON rs.id = a.source_id
            GROUP BY i.id, i.name_cn
            ORDER BY count DESC
        """)
        stats['articles_by_industry'] = [dict(row) for row in cursor.fetchall()]

        return stats

    def close(self):
        """Close the connection if open; a later call reopens it lazily."""
        if self.conn:
            self.conn.close()
            self.conn = None