import asyncio
import json

import aiohttp
from django.http import StreamingHttpResponse
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import filters, status, viewsets
from rest_framework.decorators import action
from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response

from selenium_django.settings import api_info

from .models import Task, TaskDetail
from .serializers import TaskDetailSerializer, TaskListSerializer, TaskSerializer
from .tasks import trigger_task_execution


# Pagination settings
class StandardResultsSetPagination(PageNumberPagination):
    page_size = 10
    page_size_query_param = 'page_size'
    max_page_size = 100


def sync_stream(generator):
    """Wrap an async generator as a synchronous iterator."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    async_gen = generator
    try:
        while True:
            try:
                # Pull the next chunk from the async generator
                chunk = loop.run_until_complete(async_gen.__anext__())
                if chunk and chunk.strip():
                    yield chunk
            except StopAsyncIteration:
                break
    finally:
        loop.close()


async def call_model_stream(messages):
    url = f"{api_info['base_url']}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_info['api_key']}"
    }
    payload = {
        "model": api_info["model"],
        "messages": messages,
        "max_output_tokens": 1024,  # note: some OpenAI-compatible backends expect "max_tokens" here instead
        "stream": True
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as resp:
            # Parse the SSE stream: each payload line looks like "data: {...}"
            async for line in resp.content:
                if not line:
                    continue
                line_str = line.decode().strip()
                if not line_str.startswith("data: "):
                    continue
                data_str = line_str[len("data: "):]
                if data_str == "[DONE]":
                    break
                data_json = json.loads(data_str)
                delta = data_json.get("choices", [{}])[0].get("delta", {}).get("content", "")
                if delta and delta.strip():  # only yield non-empty chunks
                    yield delta


class TaskViewSet(viewsets.ModelViewSet):
    queryset = Task.objects.all().order_by('-created_at')
    pagination_class = StandardResultsSetPagination
    filter_backends = [DjangoFilterBackend, filters.SearchFilter, filters.OrderingFilter]
    filterset_fields = ['task_id', 'status']
    search_fields = ['name', 'site']
    ordering_fields = ['created_at', 'updated_at']

    def get_serializer_class(self):
        if self.action == 'list':
            return TaskListSerializer  # list returns the simplified field set
        return TaskSerializer  # retrieve returns the full fields, including details

    @action(detail=True, methods=["post"])
    def trigger(self, request, pk=None):
        task = self.get_object()
        try:
            # Trigger the Celery task asynchronously
            async_result = trigger_task_execution.delay(task.id)
            # Report that the task was triggered without blocking on async_result
            return Response({
                "success": True,
                "task_id": async_result.id,
                "message": f"Task {task.id} has been triggered"
            }, status=status.HTTP_200_OK)
        except Exception as e:
            return Response({
                "success": False,
                "message": str(e)
            }, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

    @action(detail=True, methods=['post'])
    def chat(self, request, pk=None):
        task = self.get_object()
        user_question = request.data.get("question", "")
        if not user_question:
            return Response({"success": False, "message": "The 'question' parameter must not be empty"},
                            status=status.HTTP_400_BAD_REQUEST)

        # Build the structured document list for this task
        all_docs = TaskDetail.objects.filter(task=task)
        all_docs_list = []
        for doc in all_docs:
            all_docs_list.append({
                "title": doc.title or "",
                "summary": doc.summary or "",
                "parsed_summary": doc.parsed_summary or "",
                "author": doc.author or "",
                "original_link": doc.original_link or "",
                "pdf_url": doc.pdf_url or "",
                "source": doc.source or "",
                "keywords": doc.keywords or ""
            })
        all_docs_json = json.dumps(all_docs_list, ensure_ascii=False)
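        # Assemble the chat messages: a system prompt that restricts the model to the
        # serialized task documents above, followed by the user's question.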
        SYSTEM_PROMPT = """
You are a professional literature Q&A assistant. Answer the user's question strictly from the task documents provided.
The task documents are supplied as a structured JSON list; each document contains the fields:
"title", "summary", "parsed_summary", "author", "original_link", "pdf_url", "source", "keywords"
Requirements:
1. Answer only from the document content; do not add outside knowledge.
2. Answer the user's question directly; do not output the whole JSON.
3. If the documents lack the relevant information, you may state "not provided".
4. Keep the output readable, with no extraneous content or extra JSON structure.
"""
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Task documents:\n{all_docs_json}\nUser question: {user_question}"}
        ]
        # Stream the model output back via Django's StreamingHttpResponse
        response = StreamingHttpResponse(
            sync_stream(call_model_stream(messages)),
            content_type="text/event-stream"
        )
        return response


class TaskDetailViewSet(viewsets.ModelViewSet):
    queryset = TaskDetail.objects.all().order_by('-created_at')
    serializer_class = TaskDetailSerializer
    pagination_class = StandardResultsSetPagination
    filter_backends = [filters.SearchFilter, filters.OrderingFilter]
    search_fields = ['title', 'author', 'site']

    def get_queryset(self):
        queryset = super().get_queryset()
        task_id = self.request.query_params.get('task')
        if task_id and task_id.isdigit():
            queryset = queryset.filter(task_id=int(task_id))
            # Deduplicate by title within the single task, at the Python level
            # (note: this branch returns a plain list rather than a QuerySet)
            seen_titles = set()
            unique_queryset = []
            for obj in queryset:
                if obj.title not in seen_titles:
                    unique_queryset.append(obj)
                    seen_titles.add(obj.title)
            return unique_queryset
        return queryset

    def create(self, request, *args, **kwargs):
        """
        Incremental, per-task insertion on top of the native create endpoint.
        """
        task_id = request.data.get('task_id')
        if not task_id:
            return Response({"detail": "Missing task_id"}, status=status.HTTP_400_BAD_REQUEST)

        data_list = request.data.get('data', [])
        if not data_list:
            return Response({"detail": "Missing data"}, status=status.HTTP_400_BAD_REQUEST)

        added_count = 0
        skipped_titles = []
        for data in data_list:
            title = data.get('title')
            if not title:
                continue
            # Skip if a detail with the same title already exists under this task
            if TaskDetail.objects.filter(task_id=task_id, title=title).exists():
                skipped_titles.append(title)
                continue
            # Otherwise create it
            serializer = self.get_serializer(data={**data, "task_id": task_id})
            serializer.is_valid(raise_exception=True)
            serializer.save()
            added_count += 1

        return Response({
            "added_count": added_count,
            "skipped_titles": skipped_titles
        }, status=status.HTTP_201_CREATED)
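# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of this module): how these viewsets
# might be registered in a project-level urls.py with DRF's DefaultRouter.
# The route prefixes and the "yourapp" import path are illustrative only.
#
#   from rest_framework.routers import DefaultRouter
#   from yourapp.views import TaskViewSet, TaskDetailViewSet  # hypothetical app path
#
#   router = DefaultRouter()
#   router.register(r"tasks", TaskViewSet, basename="task")
#   router.register(r"task-details", TaskDetailViewSet, basename="task-detail")
#   urlpatterns = router.urls
# ---------------------------------------------------------------------------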