209 lines
7.8 KiB
Python
209 lines
7.8 KiB
Python
|
import asyncio
|
|||
|
import json
|
|||
|
|
|||
|
from django.http import StreamingHttpResponse
|
|||
|
from django.shortcuts import render
|
|||
|
import aiohttp
|
|||
|
from rest_framework.decorators import api_view
|
|||
|
import asyncio
|
|||
|
from django_filters.rest_framework import DjangoFilterBackend
|
|||
|
# Create your views here.
|
|||
|
from rest_framework import viewsets, filters
|
|||
|
from rest_framework.pagination import PageNumberPagination
|
|||
|
from .models import Task, TaskDetail
|
|||
|
from .serializers import TaskSerializer, TaskDetailSerializer, TaskListSerializer
|
|||
|
from rest_framework.decorators import action
|
|||
|
from rest_framework.response import Response
|
|||
|
from rest_framework import status
|
|||
|
from .tasks import trigger_task_execution
|
|||
|
import threading
|
|||
|
# 分页设置
|
|||
|
class StandardResultsSetPagination(PageNumberPagination):
    """Page-number pagination settings used by the viewsets in this module."""

    # Query parameter through which a client may request a page size.
    page_size_query_param = 'page_size'
    # Page size applied when the client does not request one.
    page_size = 10
    # Largest page size a client is allowed to request.
    max_page_size = 100
|
|||
|
from selenium_django.settings import api_info
|
|||
|
|
|||
|
def sync_stream(generator):
    """Wrap an asynchronous iterator as a synchronous generator.

    A private event loop is created when iteration starts and is used to
    pull one item at a time from ``generator``. When iteration ends for any
    reason (exhaustion, an error, or the consumer closing this generator
    early) the async generator is closed on that same loop so its own
    ``finally`` / ``async with`` cleanup can run, and the loop is then
    closed.

    Args:
        generator: An async generator (or any object exposing ``__anext__``)
            producing string chunks.

    Yields:
        Every chunk that is non-empty after stripping whitespace, unchanged.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    async_gen = generator
    try:
        while True:
            try:
                # Fetch the next item from the async generator.
                chunk = loop.run_until_complete(async_gen.__anext__())
            except StopAsyncIteration:
                break
            # NOTE(review): whitespace-only chunks are dropped here, as in the
            # original implementation; confirm that is intended for text
            # streams where a chunk may be a lone space or newline.
            if chunk and chunk.strip():
                yield chunk
    finally:
        try:
            # Closing must happen while the loop is still usable: cleanup in
            # the async generator may itself need to await.
            aclose = getattr(async_gen, "aclose", None)
            if aclose is not None:
                loop.run_until_complete(aclose())
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            # Do not leave a closed loop installed as this thread's current
            # event loop.
            asyncio.set_event_loop(None)
            loop.close()
|
|||
|
|
|||
|
def _delta_from_event(data_str):
    """Return the ``choices[0].delta.content`` text of one JSON event, or ``""``.

    Tolerates events whose ``choices`` list is missing or empty and whose
    ``delta`` or ``content`` is missing or null. Malformed JSON still raises
    ``json.JSONDecodeError``.
    """
    event = json.loads(data_str)
    choices = event.get("choices") or [{}]
    delta = choices[0].get("delta") or {}
    return delta.get("content") or ""


async def call_model_stream(messages):
    """Stream text deltas from the configured chat-completions endpoint.

    Posts ``messages`` to ``{api_info['base_url']}/chat/completions`` with
    ``stream`` enabled and reads the response line by line. Lines starting
    with ``data: `` are parsed as JSON events; reading stops at the
    ``[DONE]`` sentinel.

    Args:
        messages: Value sent as the ``messages`` field of the request payload.

    Yields:
        Each content delta that is non-empty after stripping whitespace.

    Raises:
        json.JSONDecodeError: If a ``data: `` line does not contain valid JSON.
    """
    url = f"{api_info['base_url']}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_info['api_key']}"
    }
    payload = {
        "model": api_info["model"],
        "messages": messages,
        "max_output_tokens": 1024,
        "stream": True
    }
    prefix = "data: "

    async with aiohttp.ClientSession() as session:
        # NOTE(review): the HTTP status is not checked; a response without
        # any "data: " lines simply produces an empty stream.
        async with session.post(url, headers=headers, json=payload) as resp:
            async for line in resp.content:
                if not line:
                    continue
                line_str = line.decode().strip()
                if not line_str.startswith(prefix):
                    continue
                data_str = line_str[len(prefix):]
                if data_str == "[DONE]":
                    break
                delta = _delta_from_event(data_str)
                # Only forward non-blank deltas.
                # NOTE(review): this also drops whitespace-only deltas (for
                # example a lone newline); confirm that is intended.
                if delta and delta.strip():
                    yield delta
|
|||
|
class TaskViewSet(viewsets.ModelViewSet):
    """Model viewset for ``Task`` with two extra detail actions.

    ``trigger`` queues execution of a task and ``chat`` streams a model
    answer to a question about the task's documents.
    """

    queryset = Task.objects.all().order_by('-created_at')
    pagination_class = StandardResultsSetPagination
    filter_backends = [DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter]
    filterset_fields = ['task_id', 'status']
    search_fields = ['name', 'site']
    ordering_fields = ['created_at', 'updated_at']

    def get_serializer_class(self):
        """Use the list serializer for ``list`` and the full one otherwise."""
        return TaskListSerializer if self.action == 'list' else TaskSerializer

    @action(detail=True, methods=["post"])
    def trigger(self, request, pk=None):
        """Queue execution of the addressed task.

        Returns a 200 response carrying the queued job id, or a 500 response
        carrying the error text if anything in the dispatch step raises.
        """
        task = self.get_object()

        try:
            # Dispatch through the task queue; the result object is only used
            # for its id and its outcome is never awaited here.
            async_result = trigger_task_execution.delay(task.id)
            body = {
                "success": True,
                "task_id": async_result.id,
                "message": f"任务 {task.id} 已触发"
            }
            return Response(body, status=status.HTTP_200_OK)
        except Exception as e:
            failure = {"success": False, "message": str(e)}
            return Response(failure, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

    @action(detail=True, methods=['post'])
    def chat(self, request, pk=None):
        """Stream a model answer to ``question`` based on the task's documents.

        Responds with 400 when ``question`` is missing or empty. Otherwise
        every ``TaskDetail`` of the task is serialised to JSON, embedded in
        the prompt, and the model output is streamed back.
        """
        task = self.get_object()
        user_question = request.data.get("question", "")
        if not user_question:
            return Response({"success": False, "message": "question 参数不能为空"}, status=400)

        # Build the structured document list; missing values become "".
        doc_fields = (
            "title",
            "summary",
            "parsed_summary",
            "author",
            "original_link",
            "pdf_url",
            "source",
            "keywords",
        )
        all_docs_list = [
            {field: getattr(doc, field) or "" for field in doc_fields}
            for doc in TaskDetail.objects.filter(task=task)
        ]
        all_docs_json = json.dumps(all_docs_list, ensure_ascii=False)

        SYSTEM_PROMPT = """
你是专业文献问答助手。请严格根据提供的任务文档回答用户问题。
任务文档内容已经结构化提供为 JSON 列表,每条文档包含字段:
"title", "summary", "parsed_summary", "author", "original_link", "pdf_url", "source", "keywords"

要求:
1. 仅基于文档内容作答,不补充外部知识。
2. 输出只需针对用户问题作答,不输出整个 JSON。
3. 如果文档中缺失相关信息,可以说明“未提供”。
4. 保持输出可读,不包含多余内容或额外 JSON 结构。
"""

        system_message = {"role": "system", "content": SYSTEM_PROMPT}
        user_message = {
            "role": "user",
            "content": f"任务文档内容:\n{all_docs_json}\n用户问题: {user_question}"
        }
        messages = [system_message, user_message]

        # Hand the bridged async stream to Django's StreamingHttpResponse.
        return StreamingHttpResponse(
            sync_stream(call_model_stream(messages)),
            content_type="text/event-stream",
        )
|
|||
|
from rest_framework import status
|
|||
|
from rest_framework.response import Response
|
|||
|
|
|||
|
class TaskDetailViewSet(viewsets.ModelViewSet):
    """Model viewset for ``TaskDetail`` with per-task title de-duplication.

    Listing with ``?task=<id>`` returns one record per title for that task.
    ``create`` accepts a batch and inserts only titles the task does not
    already have.
    """

    queryset = TaskDetail.objects.all().order_by('-created_at')
    serializer_class = TaskDetailSerializer
    pagination_class = StandardResultsSetPagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['title', 'author', 'site']

    def get_queryset(self):
        """Return the base queryset, narrowed and de-duplicated per task.

        Without a decimal ``task`` query parameter the base queryset is
        returned unchanged. With one, rows are limited to that task and, for
        each title, only the first row in queryset order is kept.

        The result is always a QuerySet (never a list) so that the filter
        backends and ``get_object()`` can keep chaining on it.
        """
        queryset = super().get_queryset()
        task_id = self.request.query_params.get('task')
        # isdecimal() rather than isdigit(): the latter accepts characters
        # such as superscript digits, which int() rejects.
        if not (task_id and task_id.isdecimal()):
            return queryset

        queryset = queryset.filter(task_id=int(task_id))

        # De-duplicate within the task: remember the pk of the first row seen
        # for every title, then restrict the queryset to those pks.
        seen_titles = set()
        unique_ids = []
        for pk, title in queryset.values_list('pk', 'title'):
            if title not in seen_titles:
                seen_titles.add(title)
                unique_ids.append(pk)
        return queryset.filter(pk__in=unique_ids)

    def create(self, request, *args, **kwargs):
        """Incrementally insert a batch of details for a single task.

        Expects ``task_id`` and ``data`` (a list of objects) in the request
        body. Entries without a title are ignored, entries whose title
        already exists for the task are reported in ``skipped_titles``, and
        the rest are validated and saved one by one.

        Returns:
            400 when ``task_id`` or ``data`` is missing or ``data`` is not a
            list of objects; otherwise 201 with ``added_count`` and
            ``skipped_titles``.

        Raises:
            rest_framework ValidationError: from
                ``serializer.is_valid(raise_exception=True)`` when an entry
                is invalid.
        """
        task_id = request.data.get('task_id')
        if not task_id:
            return Response({"detail": "缺少 task_id"}, status=status.HTTP_400_BAD_REQUEST)

        data_list = request.data.get('data', [])
        if not data_list:
            return Response({"detail": "缺少 data"}, status=status.HTTP_400_BAD_REQUEST)
        # Reject malformed payloads up front instead of failing with an
        # AttributeError on ``.get`` inside the loop.
        if not isinstance(data_list, list) or not all(isinstance(item, dict) for item in data_list):
            return Response({"detail": "data 必须是对象列表"}, status=status.HTTP_400_BAD_REQUEST)

        added_count = 0
        skipped_titles = []

        # NOTE(review): entries are saved one at a time without a surrounding
        # transaction, so rows saved before a validation error stay saved.
        for data in data_list:
            title = data.get('title')
            if not title:
                continue

            # Skip titles that already exist under the same task.
            if TaskDetail.objects.filter(task_id=task_id, title=title).exists():
                skipped_titles.append(title)
                continue

            # Otherwise create the record.
            serializer = self.get_serializer(data={**data, "task_id": task_id})
            serializer.is_valid(raise_exception=True)
            serializer.save()
            added_count += 1

        return Response({
            "added_count": added_count,
            "skipped_titles": skipped_titles
        }, status=status.HTTP_201_CREATED)
|
|||
|
|