Files
20250715-66bfff96/代码实现/search_engine.py
2026-04-25 19:21:03 +08:00

461 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""
搜索引擎主类
"""
import requests
import logging
import time
from typing import List, Dict, Optional, Tuple
from datetime import datetime, timedelta
from database import DatabaseManager
from config import API_CONFIG, SEARCH_CONFIG
class SearchEngine:
"""智能搜索引擎"""
def __init__(self):
    """Create the database handle, the logger and cache the API credentials.

    Credentials are read once from ``API_CONFIG``; a missing section or key
    raises ``KeyError`` here rather than at search time.
    """
    self.db = DatabaseManager()
    self.logger = logging.getLogger(__name__)

    # Each provider keeps its credential under its own section of API_CONFIG.
    newsapi_section = API_CONFIG['newsapi']
    self.newsapi_key = newsapi_section['key']

    twitter_section = API_CONFIG['twitter']
    self.twitter_token = twitter_section['bearer_token']

    alpha_vantage_section = API_CONFIG['alpha_vantage']
    self.alpha_vantage_key = alpha_vantage_section['key']
def search(self, query: str, industry: str = None,
           language: str = None, user_ip: str = '') -> Dict:
    """Run one search over every enabled source and return a response dict.

    Steps, in order: parse the query, create a search-log row, query the
    local database, query NewsAPI (only when a key is configured and the
    detected language is ``'en'``), query the financial-data API (only when
    ``industry == 'finance'`` and a key is configured), then de-duplicate,
    score and persist the merged results.

    Args:
        query: Raw user query text.
        industry: Industry name; matched against ``name_en`` when resolving
            the industry id, and passed on to the NewsAPI search.
        language: Language code; detected from the query when falsy.
        user_ip: Caller address stored with the search-log row.

    Returns:
        On success a dict with ``success=True`` and the keys
        ``search_log_id``, ``query``, ``keywords``, ``industry``,
        ``language``, ``results``, ``total_count``, ``search_time`` and
        ``sources_searched``. If anything raises after the log row exists,
        a dict with ``success=False``, ``error``, ``search_log_id`` and
        ``query``.
    """
    began_at = time.time()

    # Resolve keywords, industry id and language from the raw inputs.
    parsed = self._parse_query(query, industry, language)
    keywords = parsed['keywords']
    industry_id = parsed['industry_id']
    detected_language = parsed['language']
    self.logger.info(f"开始搜索: {keywords}, 行业: {industry}, 语言: {detected_language}")

    # The log row is created outside the try block, so a failure here
    # propagates to the caller instead of becoming an error response.
    search_log_id = self.db.create_search_log(
        keywords=' '.join(keywords),
        industry_id=industry_id,
        language=detected_language,
        user_ip=user_ip,
    )

    try:
        gathered = []

        # Source 1: local article database.
        db_results = self._search_database(keywords, industry_id, detected_language)
        gathered += db_results
        self.logger.info(f"数据库搜索结果: {len(db_results)}")

        # Source 2: NewsAPI, English queries only.
        if self.newsapi_key and detected_language == 'en':
            news_results = self._search_newsapi(keywords, industry)
            gathered += news_results
            self.logger.info(f"NewsAPI搜索结果: {len(news_results)}")

        # Source 3: financial news, finance industry only.
        if industry == 'finance' and self.alpha_vantage_key:
            finance_results = self._search_financial_data(keywords)
            gathered += finance_results
            self.logger.info(f"金融数据搜索结果: {len(finance_results)}")

        # De-duplicate, score, filter and cap the merged list.
        final_results = self._process_results(gathered, keywords)

        # Nothing is persisted for an empty result list.
        if final_results:
            self.db.save_search_results(search_log_id, final_results)

        elapsed = time.time() - began_at
        response = {
            'success': True,
            'search_log_id': search_log_id,
            'query': query,
            'keywords': keywords,
            'industry': industry,
            'language': detected_language,
            'results': final_results,
            'total_count': len(final_results),
            'search_time': round(elapsed, 2),
            'sources_searched': self._get_sources_info(industry_id),
        }
        return response
    except Exception as e:
        self.logger.error(f"搜索过程出错: {e}")
        failure = {
            'success': False,
            'error': str(e),
            'search_log_id': search_log_id,
            'query': query,
        }
        return failure
def _parse_query(self, query: str, industry: str = None,
language: str = None) -> Dict:
"""解析搜索查询"""
# 提取关键词
keywords = self._extract_keywords(query)
# 检测语言
if not language:
language = self._detect_language(query)
# 获取行业ID
industry_id = None
if industry:
industries = self.db.get_industries()
for ind in industries:
if ind['name_en'] == industry:
industry_id = ind['id']
break
return {
'keywords': keywords,
'industry_id': industry_id,
'language': language
}
def _extract_keywords(self, query: str) -> List[str]:
"""提取搜索关键词"""
import re
# 基础关键词提取
words = re.findall(r'\b\w+\b', query.lower())
# 过滤停用词
stop_words = {
'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
'by', 'from', 'up', 'about', 'into', 'through', 'during', 'before',
'after', 'above', 'below', 'up', 'down', 'out', 'off', 'over', 'under',
'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where',
'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most',
'other', 'some', 'such', 'only', 'own', 'same', 'so', 'than', 'too',
'very', 'can', 'will', 'just', 'should', 'now', 'what', 'news',
'latest', 'recent', 'update', 'today', 'yesterday'
}
keywords = [word for word in words if len(word) > 2 and word not in stop_words]
# 保留原始查询中的重要短语
phrases = self._extract_phrases(query)
keywords.extend(phrases)
return list(set(keywords)) # 去重
def _extract_phrases(self, query: str) -> List[str]:
"""提取重要短语"""
import re
# 提取引号内的短语
quoted_phrases = re.findall(r'"([^"]*)"', query)
# 提取常见的技术术语和公司名
phrases = []
# 技术术语模式
tech_patterns = [
r'\b[A-Z]{2,}\b', # 大写缩写 (AI, API, GDP)
r'\b\w+\.\w+\b', # 域名格式
r'\b\w+-\w+\b', # 连字符词组
]
for pattern in tech_patterns:
matches = re.findall(pattern, query)
phrases.extend(matches)
phrases.extend(quoted_phrases)
return phrases
def _detect_language(self, query: str) -> str:
    """Guess the query language.

    Returns ``'cn'`` when the query contains any entry of
    ``SEARCH_CONFIG['keywords_for_china']`` or any character in the range
    U+4E00..U+9FFF; otherwise returns ``SEARCH_CONFIG['default_language']``.

    Args:
        query: Raw user query text.
    """
    import re

    # Configured China-related keywords take precedence.
    china_terms = SEARCH_CONFIG['keywords_for_china']
    if any(term in query for term in china_terms):
        return 'cn'

    # Any character in the U+4E00..U+9FFF block also marks the query as 'cn'.
    if re.search(r'[\u4e00-\u9fff]+', query) is not None:
        return 'cn'

    return SEARCH_CONFIG['default_language']
def _search_database(self, keywords: List[str], industry_id: Optional[int],
                     language: str) -> List[Dict]:
    """Query the local article store.

    Args:
        keywords: Keywords to search for.
        industry_id: Industry filter, or ``None`` for no industry filter.
        language: Detected language; ``'cn'`` is sent as ``None`` so no
            language filter is passed for such queries.

    Returns:
        Whatever ``self.db.search_articles`` returns, limited to
        ``SEARCH_CONFIG['max_results_per_source']`` rows.
    """
    # 'cn' queries are not restricted by language.
    language_filter = None if language == 'cn' else language
    return self.db.search_articles(
        keywords=keywords,
        industry_id=industry_id,
        language=language_filter,
        limit=SEARCH_CONFIG['max_results_per_source'],
    )
def _search_newsapi(self, keywords: List[str], industry: str = None) -> List[Dict]:
"""使用NewsAPI搜索"""
if not self.newsapi_key:
return []
try:
url = f"{API_CONFIG['newsapi']['base_url']}everything"
# 构建查询字符串
query_str = ' AND '.join(keywords[:5]) # 限制关键词数量
params = {
'q': query_str,
'apiKey': self.newsapi_key,
'language': 'en',
'sortBy': 'relevancy',
'pageSize': 20,
'from': (datetime.now() - timedelta(days=30)).isoformat()
}
# 添加行业相关域名
if industry:
domains = self._get_industry_domains(industry)
if domains:
params['domains'] = ','.join(domains)
response = requests.get(url, params=params, timeout=30)
response.raise_for_status()
data = response.json()
articles = []
for article in data.get('articles', []):
processed_article = {
'id': f"newsapi_{hash(article['url'])}",
'title': article['title'],
'content': article.get('description', ''),
'summary': article.get('description', ''),
'author': article.get('author', ''),
'original_url': article['url'],
'published_date': self._parse_date(article.get('publishedAt')),
'source_name': article['source']['name'],
'authority_level': 2, # 默认主流媒体级别
'language': 'en',
'relevance_score': 0.8 # NewsAPI结果相关性较高
}
articles.append(processed_article)
self.logger.info(f"NewsAPI返回 {len(articles)} 条结果")
return articles
except Exception as e:
self.logger.error(f"NewsAPI搜索失败: {e}")
return []
def _search_financial_data(self, keywords: List[str]) -> List[Dict]:
"""搜索金融数据"""
if not self.alpha_vantage_key:
return []
try:
# 检查关键词是否包含股票代码
stock_symbols = self._extract_stock_symbols(keywords)
if not stock_symbols:
return []
articles = []
for symbol in stock_symbols[:3]: # 限制查询数量
data = self._get_stock_news(symbol)
if data:
articles.extend(data)
return articles
except Exception as e:
self.logger.error(f"金融数据搜索失败: {e}")
return []
def _extract_stock_symbols(self, keywords: List[str]) -> List[str]:
"""提取股票代码"""
import re
symbols = []
for keyword in keywords:
# 检查是否为股票代码格式
if re.match(r'^[A-Z]{1,5}$', keyword.upper()):
symbols.append(keyword.upper())
# 添加一些常见公司的股票代码映射
company_symbols = {
'apple': 'AAPL', 'microsoft': 'MSFT', 'google': 'GOOGL',
'amazon': 'AMZN', 'tesla': 'TSLA', 'meta': 'META',
'nvidia': 'NVDA', 'intel': 'INTC', 'amd': 'AMD'
}
for keyword in keywords:
if keyword.lower() in company_symbols:
symbols.append(company_symbols[keyword.lower()])
return list(set(symbols))
def _get_stock_news(self, symbol: str) -> List[Dict]:
    """Fetch recent news for one ticker via the ``NEWS_SENTIMENT`` function.

    Changes from the previous implementation:
      * ``relevance_score`` is taken from the ``ticker_sentiment`` entry
        for ``symbol`` (falling back to 0.5). Before, the article's
        ``overall_sentiment_score`` was used; that is a sentiment measure,
        not a relevance measure, so it could push negative-sentiment
        articles below the relevance cut-off.
        NOTE(review): assumes feed items carry a ``ticker_sentiment`` list
        of dicts with ``ticker`` and ``relevance_score`` -- confirm against
        the Alpha Vantage response schema.
      * feed items without a URL or title are skipped individually instead
        of raising and discarding the whole feed;
      * ``null`` summary/authors/source values are tolerated;
      * ids are derived from a SHA-256 digest of the URL, because the
        built-in ``hash()`` of a string varies between interpreter
        processes.

    Args:
        symbol: Ticker symbol to query.

    Returns:
        List of article dicts; ``[]`` on any error (the error is logged).
    """
    import hashlib

    try:
        url = API_CONFIG['alpha_vantage']['base_url']
        params = {
            'function': 'NEWS_SENTIMENT',
            'tickers': symbol,
            'apikey': self.alpha_vantage_key,
            'limit': 10
        }
        response = requests.get(url, params=params, timeout=30)
        response.raise_for_status()
        data = response.json()

        wanted = symbol.upper()
        articles = []
        for item in data.get('feed') or []:
            item_url = item.get('url')
            title = item.get('title')
            if not item_url or not title:
                # Downstream de-duplication needs a URL; skip partial rows.
                continue

            # Use the per-ticker relevance when present and numeric.
            relevance = 0.5
            for entry in item.get('ticker_sentiment') or []:
                if str(entry.get('ticker', '')).upper() == wanted:
                    try:
                        relevance = float(entry.get('relevance_score'))
                    except (TypeError, ValueError):
                        pass
                    break

            summary = item.get('summary') or ''
            url_digest = hashlib.sha256(item_url.encode('utf-8')).hexdigest()[:16]

            articles.append({
                'id': f"alphavantage_{url_digest}",
                'title': title,
                'content': summary,
                'summary': summary,
                'author': ','.join(item.get('authors') or []),
                'original_url': item_url,
                'published_date': self._parse_date(item.get('time_published')),
                'source_name': item.get('source') or 'Alpha Vantage',
                'authority_level': 2,
                'language': 'en',
                'relevance_score': relevance
            })
        return articles
    except Exception as e:
        self.logger.error(f"获取 {symbol} 股票新闻失败: {e}")
        return []
def _parse_date(self, date_str: str) -> Optional[datetime]:
"""解析日期字符串"""
if not date_str:
return None
try:
# 尝试多种日期格式
formats = [
'%Y-%m-%dT%H:%M:%SZ',
'%Y-%m-%dT%H:%M:%S',
'%Y%m%dT%H%M%S',
'%Y-%m-%d %H:%M:%S',
'%Y-%m-%d'
]
for fmt in formats:
try:
return datetime.strptime(date_str, fmt)
except ValueError:
continue
return None
except Exception:
return None
def _process_results(self, results: List[Dict], keywords: List[str]) -> List[Dict]:
"""处理和排序搜索结果"""
if not results:
return []
# 去重基于URL
seen_urls = set()
unique_results = []
for result in results:
url = result.get('original_url', '')
if url and url not in seen_urls:
seen_urls.add(url)
unique_results.append(result)
# 计算最终相关性分数
for result in unique_results:
score = result.get('relevance_score', 0)
# 根据权威级别调整分数
authority_bonus = (4 - result.get('authority_level', 4)) * 0.2
score += authority_bonus
# 根据发布时间调整分数(越新越好)
pub_date = result.get('published_date')
if pub_date:
days_old = (datetime.now() - pub_date).days
time_factor = max(0, 1 - days_old / 30) # 30天内线性衰减
score += time_factor * 0.1
result['final_score'] = score
# 过滤低相关性结果
min_score = SEARCH_CONFIG['min_relevance_score']
filtered_results = [r for r in unique_results if r.get('final_score', 0) >= min_score]
# 按分数排序
filtered_results.sort(key=lambda x: x.get('final_score', 0), reverse=True)
# 限制结果数量
max_results = SEARCH_CONFIG['max_results_per_source'] * 2
return filtered_results[:max_results]
def _get_industry_domains(self, industry: str) -> List[str]:
"""获取行业相关域名"""
domain_map = {
'finance': [
'bloomberg.com', 'reuters.com', 'ft.com', 'wsj.com',
'cnbc.com', 'marketwatch.com', 'forbes.com'
],
'ai_software': [
'techcrunch.com', 'venturebeat.com', 'theverge.com',
'arstechnica.com', 'wired.com', 'technologyreview.com'
],
'healthcare_pharma': [
'statnews.com', 'fiercepharma.com', 'biopharmadive.com',
'nature.com', 'nejm.org'
]
}
return domain_map.get(industry, [])
def _get_sources_info(self, industry_id: Optional[int]) -> Dict:
"""获取搜索源信息"""
sources = self.db.get_rss_sources(industry_id)
return {
'total_sources': len(sources),
'by_authority': {
'1': len([s for s in sources if s['authority_level'] == 1]),
'2': len([s for s in sources if s['authority_level'] == 2]),
'3': len([s for s in sources if s['authority_level'] == 3])
}
}
def get_search_suggestions(self, partial_query: str, limit: int = 10) -> List[str]:
"""获取搜索建议"""
try:
# 基于历史搜索记录提供建议
history = self.db.get_search_history(limit=100)
suggestions = []
partial_lower = partial_query.lower()
for record in history:
keywords = record.get('keywords', '')
if partial_lower in keywords.lower() and keywords not in suggestions:
suggestions.append(keywords)
if len(suggestions) >= limit:
break
return suggestions
except Exception as e:
self.logger.error(f"获取搜索建议失败: {e}")
return []