# app.py
print("[DEBUG] app.py started")
import sys
sys.stdout.flush()

import json
import threading
import requests
import asyncio
import math

from flask import Flask, request, jsonify
from flask_cors import CORS

from utils.springerLink import springerLink  # crawler interface
from utils.arxiv import arxiv                # crawler interface
from utils.pubmed import pubmed              # crawler interface
from utils.wangfang import wangfang          # crawler interface
from utils.zhiwang import zhiwang            # crawler interface
from utils.weipu import weipu                # crawler interface
from utils.ieeeXplore import ieeeXplore
from parseApi.api import parse_ieee_results_all_categories_async

# config.py supplies MAX_CONCURRENT_BROWSERS (int) and api_info
# (dict with "base_url", "api_key" and "model").
from config import MAX_CONCURRENT_BROWSERS, api_info

app = Flask(__name__)
# Allow all cross-origin requests
CORS(app, resources={r"/*": {"origins": "*"}}, supports_credentials=True, allow_headers="*")

semaphore = threading.Semaphore(MAX_CONCURRENT_BROWSERS)

# The site functions are split into Chinese-site and English-site crawler lists
CHINESE_SITE_FUNCTIONS = [zhiwang, wangfang, weipu]
ENGLISH_SITE_FUNCTIONS = [ieeeXplore, arxiv, pubmed]


def translate_text(text):
    """
    Input:
        text: a sentence or a list of Chinese keywords (str)
        (the LLM endpoint, key and model come from the module-level api_info dict)
    Output:
        dict: {"chinese": [...], "english": [...]}
    """
    if not text:
        return {"chinese": [], "english": []}

    # Build the prompt. It is kept in Chinese on purpose: the model receives Chinese
    # input, extracts 1-2 core research topics, and must answer with strict JSON
    # containing matching Chinese and English keyword lists.
    prompt = (
        "你是科研助手,输入是一句话或中文关键词列表。"
        "请从输入中理解语义,提取与科研论文主题最相关、最核心的中文主题,并翻译为英文。"
        "只保留1~2个最核心主题,不要加入无关内容。"
        "输出必须严格遵守 JSON 格式,不允许有额外文字或符号:{\"chinese\": [...], \"english\": [...]}。\n"
        "示例输入输出:\n"
        "输入: '我想获取基于深度学习的图像识别方面的研究'\n"
        "输出: {\"chinese\": [\"基于深度学习的图像识别\"], \"english\": [\"Deep Learning-based Image Recognition\"]}\n"
        "输入: '图像识别在深度学习方面的研究'\n"
        "输出: {\"chinese\": [\"基于深度学习的图像识别\"], \"english\": [\"Deep Learning-based Image Recognition\"]}\n"
        "输入: '自然语言处理模型在文本分类中的应用'\n"
        "输出: {\"chinese\": [\"自然语言处理文本分类\"], \"english\": [\"NLP Text Classification\"]}\n"
        "输入: '强化学习在自动驾驶决策中的最新进展'\n"
        "输出: {\"chinese\": [\"强化学习自动驾驶决策\"], \"english\": [\"Reinforcement Learning for Autonomous Driving Decision-Making\"]}\n"
        "输入: '使用图神经网络进行社交网络分析的研究'\n"
        "输出: {\"chinese\": [\"图神经网络社交网络分析\"], \"english\": [\"Graph Neural Networks for Social Network Analysis\"]}\n"
        "输入: '我想研究深度强化学习在机器人控制中的应用'\n"
        "输出: {\"chinese\": [\"深度强化学习机器人控制\"], \"english\": [\"Deep Reinforcement Learning for Robot Control\"]}\n"
        f"现在请对输入提取核心主题:\n输入: {text}"
    )

    url = f"{api_info['base_url']}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_info['api_key']}"
    }
    payload = {
        "model": api_info["model"],
        "messages": [{"role": "user", "content": prompt}],
        # Chat Completions endpoints name the token cap "max_tokens"
        # (the original "max_output_tokens" is not part of that schema).
        "max_tokens": 512
    }

    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=30)
        resp.raise_for_status()
        result = resp.json()

        text_output = result.get("choices", [{}])[0].get("message", {}).get("content", "")
        if not text_output:
            return {"chinese": [text], "english": []}

        try:
            parsed = json.loads(text_output)
            chinese = parsed.get("chinese", [text])
            english = parsed.get("english", [])
            return {"chinese": chinese, "english": english}
        except json.JSONDecodeError:
            return {"chinese": [text], "english": []}

    except requests.RequestException as e:
        print(f"[ERROR] Request failed: {e}")
        return {"chinese": [text], "english": []}

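
# Illustrative result shape for translate_text (the example values mirror the
# few-shot pair in the prompt above; a real call depends on the configured model):
#   translate_text("我想获取基于深度学习的图像识别方面的研究")
#   -> {"chinese": ["基于深度学习的图像识别"],
#       "english": ["Deep Learning-based Image Recognition"]}
# On request or JSON-parse failure it degrades to {"chinese": [text], "english": []}.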
async def crawl_single(keyword, site_func, limit, sort):
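    # The site crawlers (zhiwang, weipu, ieeeXplore, ...) are synchronous,
    # browser-driving functions, so each call is offloaded to the default
    # ThreadPoolExecutor via run_in_executor to keep the event loop responsive.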
    # get_running_loop() is the non-deprecated way to grab the loop inside a coroutine
    loop = asyncio.get_running_loop()
    try:
        print(f"[DEBUG] Opening browser for {site_func.__name__} with keyword '{keyword}'")
        result = await loop.run_in_executor(
            None,
            lambda: site_func(keyword, limit, sort_options=sort)
        )
        print(f"[DEBUG] Finished crawling {site_func.__name__} with keyword '{keyword}'")
        return result
    except Exception as e:
        print(f"[ERROR] {site_func.__name__} with keyword '{keyword}' failed: {e}")
        return []

async def crawl_and_parse(kw, site_func, limit, sort, parse_flag):
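    # Crawl one (keyword, site) pair; when parse_flag is set and the crawl returned
    # anything, run the shared category parser over the raw results before returning.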
    try:
        results = await crawl_single(kw, site_func, limit, sort)
        if parse_flag and results:
            print("Data before parsing:", results)
            parsed_results = await parse_ieee_results_all_categories_async(results)
            print(f"[DEBUG] Parsed results: {parsed_results}")
            return parsed_results or []
        return results or []
    except Exception as e:
        print(f"[ERROR] {site_func.__name__} with keyword '{kw}' failed: {e}")
        return []


# crawl_all_keywords does not need many changes; just keep the semaphore-based concurrency control
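# Fan-out: every Chinese keyword is paired with every Chinese-site crawler and every
# English keyword with every English-site crawler; an asyncio.Semaphore caps how many
# of those crawls run concurrently (max_concurrent).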
async def crawl_all_keywords(chinese_keywords, english_keywords, limit, sort,
                             max_concurrent=MAX_CONCURRENT_BROWSERS, parse_flag=True):
    all_tasks = []

    # Chinese keywords -> Chinese sites
    for kw in chinese_keywords:
        for func in CHINESE_SITE_FUNCTIONS:
            all_tasks.append((kw, func))
    # English keywords -> English sites
    for kw in english_keywords:
        for func in ENGLISH_SITE_FUNCTIONS:
            all_tasks.append((kw, func))

    semaphore = asyncio.Semaphore(max_concurrent)

    async def sem_task(kw, func):
        async with semaphore:
            return await crawl_and_parse(kw, func, limit, sort, parse_flag)

    tasks = [sem_task(kw, func) for kw, func in all_tasks]
    all_results = await asyncio.gather(*tasks, return_exceptions=True)

    final_results = []
    weipu_empty = []  # keywords whose weipu crawl came back empty

    # Collect the first round of results
    for (kw, func), r in zip(all_tasks, all_results):
        if isinstance(r, dict):
            for category, papers in r.items():
                final_results.extend(papers)
        elif isinstance(r, list):
            final_results.extend(r)
        # If this task was weipu and it returned nothing, remember the keyword
        if func is weipu and not r:
            weipu_empty.append(kw)

    # ---- Added logic only: retry the keywords whose weipu results were empty ----
    for kw in weipu_empty:
        try:
            print(f"[INFO] Weipu empty for '{kw}', retrying...")
            retry_res = await crawl_and_parse(kw, weipu, limit, sort, parse_flag)
            if isinstance(retry_res, dict):
                for category, papers in retry_res.items():
                    final_results.extend(papers)
            elif isinstance(retry_res, list):
                final_results.extend(retry_res)
        except Exception as e:
            print(f"[ERROR] Weipu retry failed for '{kw}': {e}")
    # -----------------------------------------------------------------------------

    return final_results


@app.route("/crawl", methods=["POST", "OPTIONS"])
def crawl():
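    # POST /crawl: condense the request text into Chinese + English keywords via the
    # LLM, fan the crawl out across the configured sites, and return a flat result list.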
    if request.method == "OPTIONS":
        return jsonify({"status": "ok"}), 200
    data = request.json
    if not data or "texts" not in data:
        return jsonify({"success": False, "error": "Missing 'texts' field"}), 400

    text_input = data["texts"]
    parse_flag = data.get("parse", True)
    print("Natural-language input:", text_input)
    sort = data.get("sort", ["relevance"])
    max_concurrent = int(data.get("max_concurrent", 3))

    max_retries = 3
    translated = translate_text(text_input)
    chinese_keywords = translated.get("chinese", [])
    english_keywords = translated.get("english", [])

    retry_count = 0
    while not english_keywords and retry_count < max_retries:
        retry_count += 1
        retry_translated = translate_text(text_input)
        # Keep the first (or latest) Chinese keyword result
        chinese_keywords = retry_translated.get("chinese", chinese_keywords)
        english_keywords = retry_translated.get("english", [])
        if english_keywords:
            break  # got English keywords, stop retrying

    print(translated)

    raw_limit = data.get("limit")
    if raw_limit is not None:
        raw_limit = int(raw_limit)
        # Spread the requested total across every (keyword, site) task; each site gets at least 1.
        total_tasks = (len(chinese_keywords) * len(CHINESE_SITE_FUNCTIONS)
                       + len(english_keywords) * len(ENGLISH_SITE_FUNCTIONS))
        limit = max(1, math.ceil(raw_limit / max(total_tasks, 1)))
    else:
        limit = 10

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    async def main():
        results = await crawl_all_keywords(chinese_keywords, english_keywords,
                                           limit, sort, max_concurrent, parse_flag)
        return results

    try:
        final_results = loop.run_until_complete(main())
        return jsonify({"success": True, "results": final_results})
    except Exception as e:
        return jsonify({"success": False, "error": str(e)}), 500
    finally:
        loop.close()  # don't leak the per-request event loop


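# Illustrative client call (field names match what crawl() reads above; the URL
# assumes the default host/port configured below):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:5000/crawl",
#       json={
#           "texts": "基于深度学习的图像识别",   # required: sentence or keyword list
#           "limit": 30,                         # optional: total results, split across sites
#           "sort": ["relevance"],               # optional
#           "parse": True,                       # optional: run the category parser
#           "max_concurrent": 3,                 # optional
#       },
#       timeout=600,
#   )
#   print(resp.json())   # {"success": True, "results": [...]} on success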
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False)