"""REST views for mind maps and their nodes, plus AI-assisted content generation."""

import json
import logging
import time

from django.db import transaction
from django.db.models import F
from django.http import StreamingHttpResponse
from django.utils import timezone
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.response import Response

from .models import mindMap, Node
from .serializers import map_mindmap_to_doc, map_node_to_doc

logger = logging.getLogger(__name__)

# Headers attached to every CORS-enabled response (preflight and streaming).
# ``Access-Control-Allow-Credentials`` is deliberately NOT sent: browsers reject
# credentialed responses whose ``Access-Control-Allow-Origin`` is ``*``, and for
# non-credentialed requests the header is ignored, so it never had an effect.
_CORS_HEADERS = {
    'Access-Control-Allow-Origin': '*',
    'Access-Control-Allow-Methods': 'GET, POST, PUT, DELETE, OPTIONS',
    'Access-Control-Allow-Headers': 'Content-Type, Authorization, Cache-Control, X-Requested-With',
    'Access-Control-Max-Age': '86400',
}


def _apply_cors_headers(response):
    """Set the shared CORS headers on *response* and return it."""
    for header, value in _CORS_HEADERS.items():
        response[header] = value
    return response


def _sse_event(event_type, content):
    """Format one Server-Sent Events ``data:`` frame with a JSON ``{type, content}`` payload."""
    return f"data: {json.dumps({'type': event_type, 'content': content})}\n\n"


def _empty_tree():
    """Return the placeholder MindElixir tree used when a map has no visible nodes."""
    return {
        "nodeData": {
            "id": "root",
            "topic": "根节点",
            "children": []
        }
    }


def _load_mindmap_nodes(mindmap):
    """Return all non-deleted nodes of *mindmap*, oldest first."""
    return list(Node.objects.filter(mindmap=mindmap, deleted=False).order_by('created_at'))


@api_view(['GET'])
def get_mindmap(request, id):
    """Return a mind map as the MindElixir tree structure.

    Responds 404 when the map does not exist or is soft-deleted. Only
    non-deleted nodes are included in the tree.
    """
    try:
        m = mindMap.objects.get(id=id, deleted=False)
    except mindMap.DoesNotExist:
        return Response({'detail': 'mindMap not found'}, status=404)

    return Response(convert_to_mindelixir_format(m, _load_mindmap_nodes(m)))


def convert_to_mindelixir_format(mindmap, nodes):
    """Convert a flat list of nodes into the tree structure MindElixir expects.

    Args:
        mindmap: the owning ``mindMap``; its id is copied onto every node dict and
            its ``file_name`` titles the virtual root when there are several roots.
        nodes: ``Node`` instances. A node whose parent is not in *nodes* is
            treated as a top-level node.

    Returns:
        ``{"nodeData": <tree>}``. With exactly one top-level node, that node is
        the tree; with several, they are wrapped in a virtual ``"root"`` node;
        with none, a placeholder root is returned. An empty top-level node
        titled "根节点标题" with ``children_count == 0`` is skipped.
    """
    if not nodes:
        return _empty_tree()

    # Index every node's output dict by its string id so children can be attached.
    node_map = {
        str(node.id): {
            "id": str(node.id),
            "topic": node.title or "无标题",
            "data": {
                "des": node.desc or ""
            },
            "children": [],
            # Expose the mind map id on every node (both naming styles are emitted).
            "mindmapId": mindmap.id,
            "mindmap_id": mindmap.id
        }
        for node in nodes
    }

    root_nodes = []
    for node in nodes:
        element = node_map[str(node.id)]
        parent_key = str(node.parent_id) if node.parent_id else None
        if parent_key in node_map:
            node_map[parent_key]["children"].append(element)
        elif not (node.title == "根节点标题" and node.children_count == 0):
            # Top-level node; the placeholder root created for empty maps is hidden.
            root_nodes.append(element)

    if len(root_nodes) == 1:
        return {"nodeData": root_nodes[0]}
    if root_nodes:
        # Several top-level nodes: wrap them in a virtual root.
        return {
            "nodeData": {
                "id": "root",
                "topic": mindmap.file_name or "思维导图",
                "children": root_nodes
            }
        }
    return _empty_tree()


@api_view(['POST'])
@transaction.atomic
def create_mindmap(request):
    """Create a mind map, optionally seeded from a MindElixir-style node tree.

    Request body:
        title: map title (defaults to '思维导图').
        data: optional dict with ``topic``, ``des`` and nested ``children``
            (each child a dict of the same shape). Without it, an empty
            placeholder root node is created.

    Returns 201 with ``id``, ``title`` and ``nodeData`` covering every node
    created, not just the root.
    """
    payload = request.data or {}
    title = payload.get('title') or '思维导图'
    m = mindMap.objects.create(file_name=title)

    mindmap_data = payload.get('data')
    if mindmap_data and isinstance(mindmap_data, dict):
        root = Node.objects.create(
            mindmap=m,
            title=mindmap_data.get('topic', '根节点'),
            desc=mindmap_data.get('des', ''),
            is_root=True,
            parent_id=None,
            children_count=0,
            depth=0,
            deleted=False,
        )
        if mindmap_data.get('children'):
            create_nodes_recursively(mindmap_data['children'], m, root.id, depth=1)
        # Recount so the in-memory root agrees with what was written for its children.
        root.children_count = Node.objects.filter(parent_id=root.id, deleted=False).count()
        root.save()
    else:
        Node.objects.create(
            mindmap=m,
            title='根节点标题',
            desc='',
            is_root=True,
            parent_id=None,
            children_count=0,
            depth=0,
            deleted=False,
        )

    # Build the response from every node of the new map, so seeded children are returned.
    mindelixir_data = convert_to_mindelixir_format(m, _load_mindmap_nodes(m))
    response_data = {
        'id': m.id,
        'title': m.file_name,
        'nodeData': mindelixir_data.get('nodeData', mindelixir_data)
    }
    return Response(response_data, status=201)


def create_nodes_recursively(nodes_data, mindmap, parent_id, depth=1):
    """Recursively create nodes from MindElixir-style dicts under *parent_id*.

    Args:
        nodes_data: list of dicts with optional ``topic``, ``des`` and ``children``.
            Entries that are not dicts are skipped.
        mindmap: owning ``mindMap``.
        parent_id: id of the parent node for this level.
        depth: depth assigned to this level's nodes; children get ``depth + 1``.

    After the loop, the parent's ``children_count`` is recomputed from the database.
    """
    for node_data in nodes_data:
        if not isinstance(node_data, dict):
            continue
        children = node_data.get('children') or []
        if not isinstance(children, list):
            children = []
        node = Node.objects.create(
            mindmap=mindmap,
            title=node_data.get('topic', '节点'),
            desc=node_data.get('des', ''),
            is_root=False,
            parent_id=parent_id,
            children_count=len(children),
            depth=depth,
            deleted=False,
        )
        if children:
            create_nodes_recursively(children, mindmap, node.id, depth + 1)

    if parent_id:
        try:
            parent = Node.objects.get(id=parent_id, deleted=False)
        except Node.DoesNotExist:
            return
        parent.children_count = Node.objects.filter(parent_id=parent_id, deleted=False).count()
        parent.save()


@api_view(['POST'])
@transaction.atomic
def add_nodes(request):
    """Create one node or a batch of nodes in an existing mind map.

    Request body:
        mindMapId: integer id of the target map (required).
        nodes: a node dict or a list of them, each with optional ``title``,
            ``des``, ``isRoot`` and ``parentId``.

    On any validation error the whole batch is rolled back, so no partially
    created nodes survive.
    """
    data = request.data or {}
    mindmap_id = data.get('mindMapId')
    nodes_payload = data.get('nodes', [])

    if not mindmap_id:
        return Response({'detail': 'mindMapId is required'}, status=400)

    try:
        mindmap_id = int(mindmap_id)
        m = mindMap.objects.get(id=mindmap_id, deleted=False)
    except (ValueError, TypeError):
        return Response({'detail': 'mindMapId must be a valid integer'}, status=400)
    except mindMap.DoesNotExist:
        return Response({'detail': 'mindMap not found'}, status=404)

    if isinstance(nodes_payload, dict):
        nodes_payload = [nodes_payload]
    elif not isinstance(nodes_payload, list):
        return Response({'detail': 'nodes must be an object or array'}, status=400)

    created_nodes = []
    for n in nodes_payload:
        if not isinstance(n, dict):
            # A normal return would COMMIT the atomic block (keeping nodes from
            # earlier iterations); mark it for rollback instead.
            transaction.set_rollback(True)
            return Response({'detail': 'each node must be an object'}, status=400)

        is_root = bool(n.get('isRoot', False))
        parent_id = n.get('parentId')
        parent_node = None
        if parent_id and not is_root:
            try:
                parent_node = Node.objects.get(id=parent_id, deleted=False, mindmap=m)
            except Node.DoesNotExist:
                transaction.set_rollback(True)
                return Response({'detail': f'parent node {parent_id} not found'}, status=404)

        node = Node.objects.create(
            mindmap=m,
            title=n.get('title') or '',
            desc=n.get('des') or '',
            is_root=is_root,
            parent_id=parent_id,
            children_count=0,
            depth=parent_node.depth + 1 if parent_node is not None else 0,
            deleted=False,
        )

        if parent_node is not None:
            parent_node.children_count += 1
            parent_node.save()

        created_nodes.append(node)

    resp_nodes = [map_node_to_doc(x) for x in created_nodes]
    return Response({
        'success': True,
        'message': f'成功创建 {len(created_nodes)} 个节点',
        'data': {
            'mindMapId': str(m.id),
            'nodes': resp_nodes
        }
    })


@api_view(['PATCH'])
@transaction.atomic
def update_node(request):
    """Update a node's title, description and/or parent.

    Request body:
        id: node id (required).
        newTitle / newDes: replacement title / description (None becomes '').
        newParentId: new parent id; a falsy value detaches the node to the top level.

    Re-sending the current parent leaves the children counts untouched. Moving
    a node under itself or one of its descendants is rejected with 400. When a
    node moves, the depths of its descendants are shifted by the same amount.
    """
    body = request.data or {}
    node_id = body.get('id')
    if not node_id:
        return Response({'detail': 'id is required'}, status=400)

    try:
        node = Node.objects.get(id=node_id, deleted=False)
    except Node.DoesNotExist:
        return Response({'detail': 'node not found'}, status=404)

    updated_fields = []

    if 'newTitle' in body:
        node.title = body.get('newTitle') or ''
        updated_fields.append('title')

    if 'newDes' in body:
        node.desc = body.get('newDes') or ''
        updated_fields.append('des')

    if 'newParentId' in body:
        new_parent_id = body.get('newParentId')
        old_parent_id = node.parent_id
        old_depth = node.depth
        # Compare as strings: the request may carry "5" while the model holds 5.
        old_key = str(old_parent_id) if old_parent_id else None
        new_key = str(new_parent_id) if new_parent_id else None
        parent_changed = old_key != new_key
        subtree = _collect_subtree_ids(node.mindmap_id, [str(node.id)]) if parent_changed else set()

        if new_parent_id:
            if new_key in subtree:
                return Response({'detail': 'cannot move a node under itself or its descendants'}, status=400)
            try:
                parent_node = Node.objects.get(id=new_parent_id, deleted=False, mindmap=node.mindmap)
            except Node.DoesNotExist:
                return Response({'detail': 'parent node not found'}, status=404)
            node.parent_id = new_parent_id
            node.depth = parent_node.depth + 1
            if parent_changed:
                parent_node.children_count += 1
                parent_node.save()
        else:
            # A falsy newParentId makes this a top-level node.
            node.parent_id = None
            node.depth = 0
        updated_fields.extend(['parentId', 'depth'])

        if parent_changed and old_parent_id:
            try:
                old_parent = Node.objects.get(id=old_parent_id, deleted=False, mindmap=node.mindmap)
            except Node.DoesNotExist:
                pass  # old parent may already be soft-deleted; nothing to decrement
            else:
                old_parent.children_count = max(0, old_parent.children_count - 1)
                old_parent.save()

        # Keep descendant depths consistent with the moved node.
        depth_delta = node.depth - old_depth
        descendants = subtree - {str(node.id)}
        if depth_delta and descendants:
            Node.objects.filter(id__in=list(descendants)).update(depth=F('depth') + depth_delta)

    # Only touch the timestamp (and persist) when something actually changed.
    if updated_fields:
        node.updated_at = timezone.now()
        node.save()

    return Response({
        'success': True,
        'message': '节点更新成功',
        'data': map_node_to_doc(node),
        'updatedFields': updated_fields
    })


def _collect_subtree_ids(mindmap_id: int, start_ids: list[str]) -> set[str]:
    """Return *start_ids* plus the ids of all their non-deleted descendants, as strings.

    Walks the tree breadth-first, issuing one query per level. Already-seen ids
    are never revisited, so a cycle in the stored parent links cannot loop forever.
    """
    seen = set(start_ids)
    frontier = list(start_ids)
    while frontier:
        child_ids = Node.objects.filter(
            mindmap_id=mindmap_id,
            parent_id__in=frontier,
            deleted=False
        ).values_list('id', flat=True)
        frontier = [str(c) for c in child_ids if str(c) not in seen]
        seen.update(frontier)
    return seen


@api_view(['DELETE'])
@transaction.atomic
def delete_nodes(request):
    """Soft-delete the given nodes together with all of their descendants.

    Request body:
        nodeIds: list of node ids (a single id is also accepted).

    Parents of deleted nodes have their ``children_count`` decremented. The
    response reports only nodes that were live and actually got deleted.
    """
    body = request.data or {}
    node_ids = body.get('nodeIds', [])
    if not node_ids:
        return Response({'detail': 'nodeIds is required'}, status=400)
    # Iterating a bare string would otherwise treat each character as an id.
    if isinstance(node_ids, (str, int)):
        node_ids = [node_ids]
    elif not isinstance(node_ids, list):
        return Response({'detail': 'nodeIds must be an array'}, status=400)

    # Group the live start nodes by mind map so each subtree is walked in its own map.
    starts_by_map = {}
    for nid, mid in Node.objects.filter(id__in=node_ids, deleted=False).values_list('id', 'mindmap_id'):
        starts_by_map.setdefault(mid, []).append(str(nid))

    if not starts_by_map:
        return Response({'success': True, 'message': '无可删除节点', 'data': {'deletedCount': 0, 'deletedNodeIds': []}})

    all_ids = set()
    for mid, starts in starts_by_map.items():
        all_ids |= _collect_subtree_ids(mid, starts)

    nodes_to_delete = Node.objects.filter(id__in=list(all_ids), deleted=False)

    # Count how many deleted children each parent loses.
    parent_updates = {}
    for parent_id in nodes_to_delete.values_list('parent_id', flat=True):
        if parent_id:
            parent_updates[parent_id] = parent_updates.get(parent_id, 0) + 1

    for parent_id, count_reduction in parent_updates.items():
        try:
            parent = Node.objects.get(id=parent_id, deleted=False)
        except Node.DoesNotExist:
            continue  # parent may already be soft-deleted
        parent.children_count = max(0, parent.children_count - count_reduction)
        parent.save()

    updated = nodes_to_delete.update(deleted=True, updated_at=timezone.now())

    return Response({
        'success': True,
        'message': '节点删除成功',
        'data': {
            'deletedCount': int(updated),
            'deletedNodeIds': list(all_ids)
        }
    })


@api_view(['POST'])
def generate_markdown(request):
    """Generate Markdown through the AI service (non-streaming).

    Request body: ``user_prompt`` (required), ``system_prompt``, ``model``,
    ``base_url``, ``api_key``. Returns ``{'markdown', 'success'}``, or a 500
    error payload when the AI call returns nothing or raises.
    """
    try:
        data = request.data
        system_prompt = data.get('system_prompt', '')
        user_prompt = data.get('user_prompt', '')
        model = data.get('model', 'glm-4.5')
        base_url = data.get('base_url', 'https://open.bigmodel.cn/api/paas/v4/')
        api_key = data.get('api_key', '')

        if not user_prompt:
            return Response({'error': '用户提示词不能为空'}, status=400)

        # Imported lazily, as in the original module.
        from .ai_service import call_ai_api

        markdown_content = call_ai_api(system_prompt, user_prompt, model, base_url, api_key)
        if not markdown_content:
            return Response({
                'error': 'AI API调用失败',
                'success': False
            }, status=500)
        return Response({
            'markdown': markdown_content,
            'success': True
        })
    except Exception as e:
        logger.exception("generate_markdown failed")
        return Response({
            'error': str(e),
            'success': False
        }, status=500)


@api_view(['POST', 'OPTIONS'])
def generate_ai_content_stream(request):
    """Stream AI-generated content to the client as Server-Sent Events.

    OPTIONS answers the CORS preflight. POST takes the same body as
    ``generate_markdown`` and emits ``start``, one ``chunk`` per non-empty piece
    returned by the AI stream, then ``end``; failures inside the stream are
    sent as an ``error`` event.
    """
    if request.method == 'OPTIONS':
        return _apply_cors_headers(Response())

    try:
        data = request.data
        user_prompt = data.get('user_prompt', '')
        system_prompt = data.get('system_prompt', '你是一个专业的思维导图内容生成助手,请根据用户的需求生成结构化的Markdown内容。')
        model = data.get('model', 'glm-4.5')
        base_url = data.get('base_url', 'https://open.bigmodel.cn/api/paas/v4/')
        api_key = data.get('api_key', '')

        if not user_prompt:
            return Response({'error': '用户提示词不能为空'}, status=400)

        from .ai_service import call_ai_api

        def generate_stream():
            try:
                logger.info("开始调用流式AI API...")
                stream = call_ai_api(system_prompt, user_prompt, model, base_url, api_key, stream=True)

                if stream is None:
                    logger.warning("AI API返回None,发送错误信号")
                    yield _sse_event('error', 'AI API调用失败')
                    return

                logger.info("开始发送流式数据...")
                yield _sse_event('start', '')

                chunk_count = 0
                for chunk in stream:
                    if chunk:
                        chunk_count += 1
                        # %.50s truncates like the old chunk[:50] and tolerates non-str chunks.
                        logger.debug("发送第%d个数据块: %.50s...", chunk_count, chunk)
                        yield _sse_event('chunk', chunk)

                logger.info("流式数据发送完成,总共%d个数据块", chunk_count)
                yield _sse_event('end', '')

            except Exception as e:
                logger.exception("流式生成过程中发生错误: %s", e)
                yield _sse_event('error', str(e))

        response = StreamingHttpResponse(
            generate_stream(),
            content_type='text/event-stream'
        )
        response['Cache-Control'] = 'no-cache'
        return _apply_cors_headers(response)

    except Exception as e:
        logger.exception("流式API处理过程中发生错误: %s", e)
        return Response({
            'error': str(e),
            'success': False
        }, status=500)


@api_view(['POST', 'OPTIONS'])
def test_stream(request):
    """Stream a canned Markdown document one character at a time (SSE smoke test).

    OPTIONS answers the CORS preflight; POST emits ``start``, one ``chunk`` per
    character with a 10 ms pause between them, then ``end``.
    """
    if request.method == 'OPTIONS':
        return _apply_cors_headers(Response())

    def generate_test_stream():
        try:
            yield _sse_event('start', '')

            test_content = "# 测试思维导图\n\n## 主要主题\n- 主题1\n- 主题2\n\n## 详细内容\n- 内容1\n- 内容2"
            for char in test_content:
                yield _sse_event('chunk', char)
                time.sleep(0.01)  # small delay to simulate incremental streaming

            yield _sse_event('end', '')

        except Exception as e:
            yield _sse_event('error', str(e))

    response = StreamingHttpResponse(
        generate_test_stream(),
        content_type='text/event-stream'
    )
    response['Cache-Control'] = 'no-cache'
    return _apply_cors_headers(response)