selenium_keyan/selenium_django/api/views.py

209 lines
7.8 KiB
Python
Executable File

import asyncio
import json
from django.http import StreamingHttpResponse
from django.shortcuts import render
import aiohttp
from rest_framework.decorators import api_view
import asyncio
from django_filters.rest_framework import DjangoFilterBackend
# Create your views here.
from rest_framework import viewsets, filters
from rest_framework.pagination import PageNumberPagination
from .models import Task, TaskDetail
from .serializers import TaskSerializer, TaskDetailSerializer, TaskListSerializer
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework import status
from .tasks import trigger_task_execution
import threading
# 分页设置
class StandardResultsSetPagination(PageNumberPagination):
    """Page-number pagination shared by the viewsets in this module.

    Pages contain 10 items unless the client asks for a different size via
    the ``page_size`` query parameter; requested sizes are capped at 100.
    """

    # Hard ceiling on a client-requested page size.
    max_page_size = 100
    # Name of the query parameter that lets a client choose the page size.
    page_size_query_param = 'page_size'
    # Default number of items per page.
    page_size = 10
from selenium_django.settings import api_info
def sync_stream(generator):
    """Consume an async iterator from synchronous code, yielding its items.

    A dedicated event loop is created for the duration of the iteration so
    that an async generator (such as ``call_model_stream``) can feed Django's
    synchronous ``StreamingHttpResponse``.

    Args:
        generator: the async iterator to drain.

    Yields:
        Each truthy item produced by ``generator``. ``None`` and empty items
        are dropped; whitespace-only items (e.g. ``"\\n"``) are kept because
        they carry the streamed text's formatting.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        while True:
            try:
                chunk = loop.run_until_complete(generator.__anext__())
            except StopAsyncIteration:
                break
            if chunk:
                yield chunk
    finally:
        try:
            # If this generator is closed before the async source is exhausted
            # (e.g. the response is closed early), finalize the async generator
            # on this loop so its ``finally`` / ``async with`` cleanup runs
            # before the loop goes away.
            aclose = getattr(generator, "aclose", None)
            if aclose is not None:
                loop.run_until_complete(aclose())
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            # Do not leave a closed loop registered as this thread's loop.
            asyncio.set_event_loop(None)
            loop.close()
async def call_model_stream(messages):
    """Stream an answer from the configured chat-completions endpoint.

    POSTs ``messages`` to ``{api_info['base_url']}/chat/completions`` with
    ``stream=True`` and parses the Server-Sent-Events body line by line.

    Args:
        messages: chat messages, a list of ``{"role": ..., "content": ...}``.

    Yields:
        str: each non-empty ``choices[0].delta.content`` fragment, in order,
        until the ``[DONE]`` marker or the end of the body.

    Raises:
        aiohttp.ClientResponseError: if the endpoint answers with an error
            status (instead of silently streaming nothing).
    """
    url = f"{api_info['base_url']}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_info['api_key']}",
    }
    payload = {
        "model": api_info["model"],
        "messages": messages,
        # NOTE(review): OpenAI's Chat Completions API calls this limit
        # ``max_tokens`` / ``max_completion_tokens``; ``max_output_tokens`` is
        # the Responses-API name. Kept as-is in case the configured provider
        # expects it — confirm against the provider's documentation.
        "max_output_tokens": 1024,
        "stream": True,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as resp:
            # An error response is a JSON body, not SSE; without this check
            # the loop below would find no "data:" lines and yield nothing.
            resp.raise_for_status()
            async for raw_line in resp.content:
                line_str = raw_line.decode("utf-8").strip()
                # SSE allows both "data: payload" and "data:payload".
                if not line_str.startswith("data:"):
                    continue
                data_str = line_str[len("data:"):].strip()
                if data_str == "[DONE]":
                    break
                try:
                    data_json = json.loads(data_str)
                except json.JSONDecodeError:
                    # Skip a malformed frame rather than abort the whole answer.
                    continue
                # Some chunks carry an empty ``choices`` list (usage or
                # content-filter frames) or ``"delta": null``.
                choices = data_json.get("choices") or [{}]
                delta = (choices[0] or {}).get("delta") or {}
                content = delta.get("content")
                # Whitespace-only fragments (e.g. "\n") are kept: they are
                # part of the answer's formatting.
                if content:
                    yield content
class TaskViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for ``Task`` plus the ``trigger`` and ``chat`` actions.

    ``list`` uses the compact ``TaskListSerializer``; every other action uses
    ``TaskSerializer``. Results are newest-first by ``created_at``.
    """

    queryset = Task.objects.all().order_by('-created_at')
    pagination_class = StandardResultsSetPagination
    filter_backends = [DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter]
    filterset_fields = ['task_id', 'status']
    search_fields = ['name', 'site']
    ordering_fields = ['created_at', 'updated_at']

    # System prompt for the ``chat`` action. It tells the model to answer only
    # from the supplied JSON documents and not to echo the JSON back.
    CHAT_SYSTEM_PROMPT = (
        "\n"
        "你是专业文献问答助手。请严格根据提供的任务文档回答用户问题。\n"
        "任务文档内容已经结构化提供为 JSON 列表,每条文档包含字段:\n"
        '"title", "summary", "parsed_summary", "author", "original_link", "pdf_url", "source", "keywords"\n'
        "要求:\n"
        "1. 仅基于文档内容作答,不补充外部知识。\n"
        "2. 输出只需针对用户问题作答,不输出整个 JSON。\n"
        "3. 如果文档中缺失相关信息,可以说明“未提供”。\n"
        "4. 保持输出可读,不包含多余内容或额外 JSON 结构。\n"
    )

    # TaskDetail attributes copied (in this order) into each prompt document.
    CHAT_DOC_FIELDS = (
        "title", "summary", "parsed_summary", "author",
        "original_link", "pdf_url", "source", "keywords",
    )

    def get_serializer_class(self):
        """Use the slim serializer for ``list`` and the full one otherwise."""
        if self.action == 'list':
            return TaskListSerializer  # list returns reduced fields
        return TaskSerializer  # other actions return all fields, incl. details

    @action(detail=True, methods=["post"])
    def trigger(self, request, pk=None):
        """Queue ``trigger_task_execution`` for this task without waiting for it.

        Returns 200 with the async result id on success, or 500 with the error
        text if queuing fails.
        """
        task = self.get_object()
        try:
            # Fire-and-forget: only the async result id is read, never its value.
            async_result = trigger_task_execution.delay(task.id)
            return Response({
                "success": True,
                "task_id": async_result.id,
                "message": f"任务 {task.id} 已触发"
            }, status=status.HTTP_200_OK)
        except Exception as e:
            return Response({
                "success": False,
                "message": str(e)
            }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

    @action(detail=True, methods=['post'])
    def chat(self, request, pk=None):
        """Stream a model answer to ``question`` based on this task's documents.

        Returns 400 if ``question`` is missing or blank; otherwise a
        ``StreamingHttpResponse`` whose body is the raw answer text fragments.
        """
        task = self.get_object()
        user_question = request.data.get("question", "")
        if isinstance(user_question, str):
            # Whitespace-only questions count as missing.
            user_question = user_question.strip()
        if not user_question:
            return Response({"success": False, "message": "question 参数不能为空"}, status=400)

        messages = self._build_chat_messages(task, user_question)
        # NOTE(review): the body is raw text fragments, not SSE-framed
        # ("data: ...\n\n") despite the content type; kept for client compat.
        response = StreamingHttpResponse(
            sync_stream(call_model_stream(messages)),
            content_type="text/event-stream",
        )
        # Keep caches and reverse proxies (nginx honours X-Accel-Buffering)
        # from buffering the stream, so fragments reach the client as produced.
        response["Cache-Control"] = "no-cache"
        response["X-Accel-Buffering"] = "no"
        return response

    def _build_chat_messages(self, task, user_question):
        """Build the system + user chat messages: all of ``task``'s documents as JSON plus the question."""
        docs = [
            {field: getattr(doc, field) or "" for field in self.CHAT_DOC_FIELDS}
            for doc in TaskDetail.objects.filter(task=task)
        ]
        all_docs_json = json.dumps(docs, ensure_ascii=False)
        return [
            {"role": "system", "content": self.CHAT_SYSTEM_PROMPT},
            {"role": "user", "content": f"任务文档内容:\n{all_docs_json}\n用户问题: {user_question}"},
        ]
from rest_framework import status
from rest_framework.response import Response
class TaskDetailViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for ``TaskDetail`` with per-task de-duplication by title.

    ``?task=<id>`` restricts results to one task and hides repeated titles;
    ``create`` accepts a batch of documents for one task and skips titles
    already stored for it.
    """

    queryset = TaskDetail.objects.all().order_by('-created_at')
    serializer_class = TaskDetailSerializer
    pagination_class = StandardResultsSetPagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    # NOTE(review): 'site' is listed here, but the TaskDetail attributes used
    # elsewhere in this module do not include it — confirm the model has this
    # field, otherwise ?search= will raise a FieldError.
    search_fields = ['title', 'author', 'site']

    def get_queryset(self):
        """Return the detail QuerySet, de-duplicated by title when ``?task=`` is given.

        For a numeric ``task`` query parameter only that task's rows are
        returned, and for each title only the first row in ``-created_at``
        order is kept. The result stays a QuerySet (not a list) so the search
        and ordering filters and ``get_object()`` keep working.
        """
        queryset = super().get_queryset()
        task_id = self.request.query_params.get('task')
        # isdecimal() (unlike isdigit()) only accepts strings int() can parse.
        if not (task_id and task_id.isdecimal()):
            return queryset

        queryset = queryset.filter(task_id=int(task_id))
        # Pick the rows to keep from a light (pk, title) scan, then filter the
        # QuerySet by pk so callers still get a QuerySet.
        seen_titles = set()
        keep_pks = []
        for pk, title in queryset.values_list('pk', 'title'):
            if title not in seen_titles:
                seen_titles.add(title)
                keep_pks.append(pk)
        return queryset.filter(pk__in=keep_pks)

    def create(self, request, *args, **kwargs):
        """Batch-insert documents for one task, skipping titles it already has.

        Expected body: ``{"task_id": ..., "data": [{"title": ..., ...}, ...]}``.
        Items that are not objects or have no title are ignored. The batch is
        all-or-nothing: if any item fails validation, nothing is saved and the
        validation error is returned.

        Returns 400 when ``task_id`` or ``data`` is missing or ``data`` is not
        a list; otherwise 201 with ``added_count`` and ``skipped_titles``.
        """
        from django.db import transaction

        task_id = request.data.get('task_id')
        if not task_id:
            return Response({"detail": "缺少 task_id"}, status=status.HTTP_400_BAD_REQUEST)
        data_list = request.data.get('data', [])
        if not data_list:
            return Response({"detail": "缺少 data"}, status=status.HTTP_400_BAD_REQUEST)
        if not isinstance(data_list, list):
            return Response({"detail": "data 必须是列表"}, status=status.HTTP_400_BAD_REQUEST)

        added_count = 0
        skipped_titles = []
        # One transaction so a validation error on a later item does not leave
        # the earlier items half-committed behind a 400 response.
        with transaction.atomic():
            for data in data_list:
                title = data.get('title') if isinstance(data, dict) else None
                if not title:
                    continue
                # Skip titles already stored for this task (including ones
                # inserted earlier in this same batch).
                if TaskDetail.objects.filter(task_id=task_id, title=title).exists():
                    skipped_titles.append(title)
                    continue
                serializer = self.get_serializer(data={**data, "task_id": task_id})
                serializer.is_valid(raise_exception=True)
                serializer.save()
                added_count += 1

        return Response({
            "added_count": added_count,
            "skipped_titles": skipped_titles
        }, status=status.HTTP_201_CREATED)