# File listing metadata: 461 lines, 17 KiB, Python.
# -*- coding: utf-8 -*-
"""
Main search engine class.
"""

import hashlib
import logging
import re
import time
from datetime import datetime, timedelta
from typing import List, Dict, Optional, Tuple

import requests

from config import API_CONFIG, SEARCH_CONFIG
from database import DatabaseManager


class SearchEngine:
    """Search engine that merges several article sources.

    A search combines hits from the local article database, NewsAPI
    (English queries only) and Alpha Vantage news (``finance`` industry
    only), then de-duplicates, scores, filters and ranks them.
    """

    # Words dropped when turning a free-text query into keywords.
    _STOP_WORDS = frozenset({
        'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
        'by', 'from', 'up', 'about', 'into', 'through', 'during', 'before',
        'after', 'above', 'below', 'down', 'out', 'off', 'over', 'under',
        'again', 'further', 'then', 'once', 'here', 'there', 'when', 'where',
        'why', 'how', 'all', 'any', 'both', 'each', 'few', 'more', 'most',
        'other', 'some', 'such', 'only', 'own', 'same', 'so', 'than', 'too',
        'very', 'can', 'will', 'just', 'should', 'now', 'what', 'news',
        'latest', 'recent', 'update', 'today', 'yesterday',
    })

    _WORD_RE = re.compile(r'\b\w+\b')
    _QUOTED_RE = re.compile(r'"([^"]*)"')
    _TECH_PATTERNS = (
        re.compile(r'\b[A-Z]{2,}\b'),  # upper-case acronyms (AI, API, GDP)
        re.compile(r'\b\w+\.\w+\b'),   # domain-like tokens
        re.compile(r'\b\w+-\w+\b'),    # hyphenated terms
    )
    _CJK_RE = re.compile(r'[\u4e00-\u9fff]')
    _TICKER_RE = re.compile(r'[A-Z]{1,5}')

    # Well-known company names mapped to their ticker symbols.
    _COMPANY_SYMBOLS = {
        'apple': 'AAPL', 'microsoft': 'MSFT', 'google': 'GOOGL',
        'amazon': 'AMZN', 'tesla': 'TSLA', 'meta': 'META',
        'nvidia': 'NVDA', 'intel': 'INTC', 'amd': 'AMD',
    }

    # Formats tried in order by _parse_date.
    _DATE_FORMATS = (
        '%Y-%m-%dT%H:%M:%SZ',
        '%Y-%m-%dT%H:%M:%S',
        '%Y%m%dT%H%M%S',
        '%Y-%m-%d %H:%M:%S',
        '%Y-%m-%d',
        '%Y-%m-%dT%H:%M:%S.%fZ',
        '%Y-%m-%dT%H:%M:%S.%f',
    )

    # News domains used to narrow NewsAPI queries per industry.
    _INDUSTRY_DOMAINS = {
        'finance': [
            'bloomberg.com', 'reuters.com', 'ft.com', 'wsj.com',
            'cnbc.com', 'marketwatch.com', 'forbes.com',
        ],
        'ai_software': [
            'techcrunch.com', 'venturebeat.com', 'theverge.com',
            'arstechnica.com', 'wired.com', 'technologyreview.com',
        ],
        'healthcare_pharma': [
            'statnews.com', 'fiercepharma.com', 'biopharmadive.com',
            'nature.com', 'nejm.org',
        ],
    }

    def __init__(self):
        """Create the database manager and read API credentials from API_CONFIG."""
        self.db = DatabaseManager()
        self.logger = logging.getLogger(__name__)
        self.newsapi_key = API_CONFIG['newsapi']['key']
        self.twitter_token = API_CONFIG['twitter']['bearer_token']
        self.alpha_vantage_key = API_CONFIG['alpha_vantage']['key']

    def search(self, query: str, industry: Optional[str] = None,
               language: Optional[str] = None, user_ip: str = '') -> Dict:
        """Run a search across all applicable sources.

        Args:
            query: Free-text query entered by the user.
            industry: Industry name; compared against ``name_en`` of the
                industries returned by the database.
            language: Language code; detected from ``query`` when falsy.
            user_ip: Caller IP, stored in the search log.

        Returns:
            On success, a dict with ``success=True`` plus ``search_log_id``,
            ``query``, ``keywords``, ``industry``, ``language``, ``results``,
            ``total_count``, ``search_time`` (seconds, 2 decimals) and
            ``sources_searched``. If anything raises after the search log
            was created, a dict with ``success=False``, ``error``,
            ``search_log_id`` and ``query``.
        """
        # perf_counter is monotonic, so the elapsed time cannot go negative
        # if the wall clock is adjusted mid-search.
        started = time.perf_counter()

        search_params = self._parse_query(query, industry, language)
        keywords = search_params['keywords']
        industry_id = search_params['industry_id']
        detected_language = search_params['language']

        self.logger.info(f"开始搜索: {keywords}, 行业: {industry}, 语言: {detected_language}")

        # The log row is created before searching so failures can be tied to it.
        search_log_id = self.db.create_search_log(
            keywords=' '.join(keywords),
            industry_id=industry_id,
            language=detected_language,
            user_ip=user_ip,
        )

        try:
            all_results = []

            # 1. Local database.
            db_results = self._search_database(keywords, industry_id, detected_language)
            all_results.extend(db_results)
            self.logger.info(f"数据库搜索结果: {len(db_results)} 条")

            # 2. NewsAPI, only when a key is configured and the query is English.
            if self.newsapi_key and detected_language == 'en':
                news_results = self._search_newsapi(keywords, industry)
                all_results.extend(news_results)
                self.logger.info(f"NewsAPI搜索结果: {len(news_results)} 条")

            # 3. Alpha Vantage news, only for the finance industry.
            if industry == 'finance' and self.alpha_vantage_key:
                finance_results = self._search_financial_data(keywords)
                all_results.extend(finance_results)
                self.logger.info(f"金融数据搜索结果: {len(finance_results)} 条")

            final_results = self._process_results(all_results, keywords)

            if final_results:
                self.db.save_search_results(search_log_id, final_results)

            elapsed = time.perf_counter() - started

            return {
                'success': True,
                'search_log_id': search_log_id,
                'query': query,
                'keywords': keywords,
                'industry': industry,
                'language': detected_language,
                'results': final_results,
                'total_count': len(final_results),
                'search_time': round(elapsed, 2),
                'sources_searched': self._get_sources_info(industry_id),
            }

        except Exception as e:
            # Top-level boundary: report the failure to the caller instead of
            # raising, but keep the traceback in the log.
            self.logger.exception(f"搜索过程出错: {e}")
            return {
                'success': False,
                'error': str(e),
                'search_log_id': search_log_id,
                'query': query,
            }

    def _parse_query(self, query: str, industry: Optional[str] = None,
                     language: Optional[str] = None) -> Dict:
        """Resolve keywords, language and industry id for a query.

        Returns:
            Dict with ``keywords`` (list of str), ``industry_id`` (the ``id``
            of the first industry whose ``name_en`` equals ``industry``, or
            ``None``) and ``language``.
        """
        keywords = self._extract_keywords(query)

        if not language:
            language = self._detect_language(query)

        industry_id = None
        if industry:
            industry_id = next(
                (row['id'] for row in self.db.get_industries()
                 if row['name_en'] == industry),
                None,
            )

        return {
            'keywords': keywords,
            'industry_id': industry_id,
            'language': language,
        }

    def _extract_keywords(self, query: str) -> List[str]:
        """Extract de-duplicated search keywords from a query.

        Lower-cased words longer than two characters that are not stop words
        come first, followed by the phrases found by ``_extract_phrases``
        (which keep their original case). Order of first appearance is kept
        so that later truncation (e.g. ``keywords[:5]``) is deterministic.
        """
        words = self._WORD_RE.findall(query.lower())
        keywords = [
            word for word in words
            if len(word) > 2 and word not in self._STOP_WORDS
        ]
        keywords.extend(self._extract_phrases(query))

        # dict.fromkeys de-duplicates while preserving insertion order.
        return list(dict.fromkeys(keywords))

    def _extract_phrases(self, query: str) -> List[str]:
        """Extract notable phrases from the raw query.

        Returns acronym, domain-like and hyphenated matches (in that pattern
        order) followed by non-blank double-quoted phrases.
        """
        phrases = []
        for pattern in self._TECH_PATTERNS:
            phrases.extend(pattern.findall(query))

        # An empty pair of quotes would otherwise produce an empty keyword.
        phrases.extend(
            quoted for quoted in self._QUOTED_RE.findall(query) if quoted.strip()
        )
        return phrases

    def _detect_language(self, query: str) -> str:
        """Return ``'cn'`` for China-related or Chinese-script queries.

        A query counts as ``'cn'`` when it contains any entry of
        ``SEARCH_CONFIG['keywords_for_china']`` or any character in the
        U+4E00..U+9FFF range; otherwise ``SEARCH_CONFIG['default_language']``
        is returned.
        """
        if any(marker in query for marker in SEARCH_CONFIG['keywords_for_china']):
            return 'cn'

        if self._CJK_RE.search(query):
            return 'cn'

        return SEARCH_CONFIG['default_language']

    def _search_database(self, keywords: List[str], industry_id: Optional[int],
                         language: str) -> List[Dict]:
        """Search the local article database."""
        # For 'cn' the language filter is passed as None.
        # NOTE(review): presumably so such queries match articles in any
        # language - confirm against DatabaseManager.search_articles.
        language_filter = None if language == 'cn' else language
        return self.db.search_articles(
            keywords=keywords,
            industry_id=industry_id,
            language=language_filter,
            limit=SEARCH_CONFIG['max_results_per_source'],
        )

    @staticmethod
    def _stable_id(prefix: str, url: str) -> str:
        """Build an identifier for ``url`` that is stable across processes.

        The built-in ``hash()`` of a str is salted per interpreter run, so it
        cannot be used for identifiers that get persisted.
        """
        digest = hashlib.sha256(url.encode('utf-8')).hexdigest()[:16]
        return f"{prefix}_{digest}"

    def _search_newsapi(self, keywords: List[str],
                        industry: Optional[str] = None) -> List[Dict]:
        """Query the NewsAPI ``everything`` endpoint.

        At most the first five keywords are AND-ed together, restricted to
        English articles from the last 30 days and, when the industry has
        known domains, to those domains.

        Returns:
            Normalised article dicts; ``[]`` when no API key is configured or
            on any error (the error is logged).
        """
        if not self.newsapi_key:
            return []

        try:
            url = f"{API_CONFIG['newsapi']['base_url']}everything"

            params = {
                'q': ' AND '.join(keywords[:5]),
                'apiKey': self.newsapi_key,
                'language': 'en',
                'sortBy': 'relevancy',
                'pageSize': 20,
                # Second precision matches the ISO 8601 form NewsAPI documents.
                'from': (datetime.now() - timedelta(days=30)).isoformat(timespec='seconds'),
            }

            if industry:
                domains = self._get_industry_domains(industry)
                if domains:
                    params['domains'] = ','.join(domains)

            response = requests.get(url, params=params, timeout=30)
            response.raise_for_status()
            payload = response.json()

            articles = []
            for item in payload.get('articles') or []:
                link = item.get('url')
                if not link:
                    # Results without a URL are discarded by _process_results
                    # anyway, and one bad item must not drop the whole batch.
                    continue
                # `or ''` also covers JSON null, which .get(key, '') does not.
                description = item.get('description') or ''
                articles.append({
                    'id': self._stable_id('newsapi', link),
                    'title': item.get('title') or '',
                    'content': description,
                    'summary': description,
                    'author': item.get('author') or '',
                    'original_url': link,
                    'published_date': self._parse_date(item.get('publishedAt')),
                    'source_name': (item.get('source') or {}).get('name') or '',
                    'authority_level': 2,    # default: mainstream media level
                    'language': 'en',
                    'relevance_score': 0.8,  # fixed score for NewsAPI hits
                })

            self.logger.info(f"NewsAPI返回 {len(articles)} 条结果")
            return articles

        except Exception as e:
            # Best effort: a failing external source must not fail the search.
            self.logger.error(f"NewsAPI搜索失败: {e}")
            return []

    def _search_financial_data(self, keywords: List[str]) -> List[Dict]:
        """Fetch stock news for up to three ticker symbols found in keywords.

        Returns ``[]`` when no Alpha Vantage key is configured, no symbol is
        found, or an error occurs (the error is logged).
        """
        if not self.alpha_vantage_key:
            return []

        try:
            stock_symbols = self._extract_stock_symbols(keywords)
            if not stock_symbols:
                return []

            articles = []
            for symbol in stock_symbols[:3]:  # cap the number of API calls
                articles.extend(self._get_stock_news(symbol) or [])
            return articles

        except Exception as e:
            self.logger.error(f"金融数据搜索失败: {e}")
            return []

    def _extract_stock_symbols(self, keywords: List[str]) -> List[str]:
        """Derive candidate ticker symbols from keywords.

        Known company names map to their ticker and are listed first, because
        the caller only queries the first few symbols. Any other keyword made
        of one to five ASCII letters is upper-cased and treated as a ticker.
        The result is de-duplicated in a deterministic order.
        """
        known = []
        guessed = []
        for keyword in keywords:
            mapped = self._COMPANY_SYMBOLS.get(keyword.lower())
            if mapped is not None:
                # Use only the mapped ticker: 'apple' means AAPL, not 'APPLE'.
                known.append(mapped)
                continue
            candidate = keyword.upper()
            # NOTE(review): keywords are mostly lower-cased upstream, so any
            # short word (e.g. 'stock') still qualifies as a ticker here.
            if self._TICKER_RE.fullmatch(candidate):
                guessed.append(candidate)

        return list(dict.fromkeys(known + guessed))

    def _get_stock_news(self, symbol: str) -> List[Dict]:
        """Fetch news for one ticker via Alpha Vantage ``NEWS_SENTIMENT``.

        Returns normalised article dicts, or ``[]`` on any error (logged).
        """
        try:
            url = API_CONFIG['alpha_vantage']['base_url']
            params = {
                'function': 'NEWS_SENTIMENT',
                'tickers': symbol,
                'apikey': self.alpha_vantage_key,
                'limit': 10,
            }

            response = requests.get(url, params=params, timeout=30)
            response.raise_for_status()
            payload = response.json()

            articles = []
            for item in payload.get('feed') or []:
                link = item.get('url')
                if not link:
                    # URL-less results are discarded by _process_results anyway.
                    continue

                # NOTE(review): the overall sentiment score is used as the
                # relevance score, so negative sentiment lowers ranking -
                # confirm this is intended.
                try:
                    score = float(item.get('overall_sentiment_score', 0.5))
                except (TypeError, ValueError):
                    score = 0.5

                summary = item.get('summary') or ''
                articles.append({
                    'id': self._stable_id('alphavantage', link),
                    'title': item.get('title') or '',
                    'content': summary,
                    'summary': summary,
                    'author': ','.join(item.get('authors') or []),
                    'original_url': link,
                    'published_date': self._parse_date(item.get('time_published')),
                    'source_name': item.get('source', 'Alpha Vantage'),
                    'authority_level': 2,
                    'language': 'en',
                    'relevance_score': score,
                })

            return articles

        except Exception as e:
            self.logger.error(f"获取 {symbol} 股票新闻失败: {e}")
            return []

    def _parse_date(self, date_str: str) -> Optional[datetime]:
        """Parse a date string using the known formats.

        Returns a naive ``datetime``, or ``None`` when ``date_str`` is empty,
        not a string, or matches none of ``_DATE_FORMATS``.
        """
        if not date_str or not isinstance(date_str, str):
            return None

        for fmt in self._DATE_FORMATS:
            try:
                return datetime.strptime(date_str, fmt)
            except ValueError:
                continue
        return None

    def _score_result(self, result: Dict, now: datetime) -> float:
        """Compute the final ranking score of one result.

        score = relevance_score
                + (4 - authority_level) * 0.2
                + recency * 0.1

        where recency decays linearly from 1 to 0 over 30 days and is clamped
        to [0, 1], so future-dated articles get no extra boost. A missing or
        unparseable ``published_date`` earns no recency bonus.

        Args:
            result: Article dict; it is not modified.
            now: Naive local "current" time used to compute the age.
        """
        score = result.get('relevance_score') or 0

        authority = result.get('authority_level')
        if authority is None:
            authority = 4  # lowest level: no bonus
        score += (4 - authority) * 0.2

        published = result.get('published_date')
        if isinstance(published, str):
            published = self._parse_date(published)
        if isinstance(published, datetime):
            if published.tzinfo is not None:
                # Convert aware timestamps to naive local time so they can be
                # subtracted from the naive `now`.
                published = published.astimezone().replace(tzinfo=None)
            days_old = (now - published).days
            recency = min(1, max(0, 1 - days_old / 30))
            score += recency * 0.1

        return score

    def _process_results(self, results: List[Dict], keywords: List[str]) -> List[Dict]:
        """De-duplicate, score, filter and rank raw results.

        Results are de-duplicated by ``original_url`` (first occurrence wins;
        results without a URL are discarded), given a ``final_score`` key,
        filtered by ``SEARCH_CONFIG['min_relevance_score']``, sorted by score
        descending and capped at twice
        ``SEARCH_CONFIG['max_results_per_source']``.

        ``keywords`` is accepted for interface compatibility and is not used.
        """
        if not results:
            return []

        seen_urls = set()
        unique_results = []
        for result in results:
            url = result.get('original_url', '')
            if url and url not in seen_urls:
                seen_urls.add(url)
                unique_results.append(result)

        now = datetime.now()
        for result in unique_results:
            result['final_score'] = self._score_result(result, now)

        min_score = SEARCH_CONFIG['min_relevance_score']
        ranked = sorted(
            (r for r in unique_results if r['final_score'] >= min_score),
            key=lambda r: r['final_score'],
            reverse=True,
        )

        max_results = SEARCH_CONFIG['max_results_per_source'] * 2
        return ranked[:max_results]

    def _get_industry_domains(self, industry: str) -> List[str]:
        """Return the news domains associated with an industry (``[]`` if none)."""
        # Return a copy so callers cannot mutate the shared class constant.
        return list(self._INDUSTRY_DOMAINS.get(industry, []))

    def _get_sources_info(self, industry_id: Optional[int]) -> Dict:
        """Summarise the RSS sources for an industry.

        Returns:
            Dict with ``total_sources`` and ``by_authority``, the latter
            counting sources whose ``authority_level`` is 1, 2 or 3 under the
            keys ``'1'``, ``'2'`` and ``'3'``.
        """
        sources = self.db.get_rss_sources(industry_id)
        levels = [source['authority_level'] for source in sources]

        return {
            'total_sources': len(sources),
            'by_authority': {
                '1': levels.count(1),
                '2': levels.count(2),
                '3': levels.count(3),
            },
        }

    def get_search_suggestions(self, partial_query: str, limit: int = 10) -> List[str]:
        """Suggest past searches that contain ``partial_query``.

        Looks at the 100 most recent history records returned by the database
        and matches case-insensitively on their ``keywords`` text.

        Returns:
            Up to ``limit`` distinct, non-empty keyword strings in history
            order; ``[]`` on error (logged) or when ``limit`` is not positive.
        """
        if limit <= 0:
            return []

        try:
            needle = partial_query.lower()
            suggestions = []
            seen = set()

            for record in self.db.get_search_history(limit=100):
                # `or ''` also covers records whose keywords value is None.
                text = record.get('keywords') or ''
                if not text or text in seen or needle not in text.lower():
                    continue
                seen.add(text)
                suggestions.append(text)
                if len(suggestions) >= limit:
                    break

            return suggestions

        except Exception as e:
            self.logger.error(f"获取搜索建议失败: {e}")
            return []