MindMap/backend/mindmap/views_doc.py

611 lines
21 KiB
Python
Raw Normal View History

from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework import status
from django.db import transaction
from django.utils import timezone
2025-09-08 10:20:48 +00:00
from django.http import StreamingHttpResponse
import json
from .models import mindMap, Node
from .serializers import map_mindmap_to_doc, map_node_to_doc
@api_view(['GET'])
def get_mindmap(request, id):
    """Return a single non-deleted mind map rendered as a MindElixir tree.

    Responds 404 when no mind map with ``id`` exists or it is soft-deleted.
    Soft-deleted nodes are NOT included; the remaining nodes are ordered by
    ``created_at`` so that siblings appear in creation order.
    """
    try:
        target_map = mindMap.objects.get(id=id, deleted=False)
    except mindMap.DoesNotExist:
        return Response({'detail': 'mindMap not found'}, status=404)

    live_nodes = list(
        Node.objects.filter(mindmap=target_map, deleted=False).order_by('created_at')
    )
    tree = convert_to_mindelixir_format(target_map, live_nodes)
    return Response(tree)
def convert_to_mindelixir_format(mindmap, nodes):
    """Convert a flat list of nodes into the tree structure MindElixir expects.

    Args:
        mindmap: The owning mind map. Its ``id`` is copied onto every node
            and its ``file_name`` titles the synthetic root that wraps
            several top-level nodes.
        nodes: Node objects exposing ``id``, ``title``, ``desc`` and
            ``parent_id``. Sibling order follows the order of this list. A
            node whose parent is not in ``nodes`` is treated as top-level.

    Returns:
        ``{"nodeData": ...}`` holding the single top-level node, a synthetic
        ``"root"`` node wrapping several top-level nodes, or a default empty
        root when nothing is left to show.
    """
    def _default_root():
        # Fresh dict each call so callers never share a mutable result.
        return {
            "nodeData": {
                "id": "root",
                "topic": "根节点",
                "children": [],
            }
        }

    if not nodes:
        return _default_root()

    # One MindElixir dict per node, keyed by the stringified node ID.
    node_map = {}
    for node in nodes:
        node_key = str(node.id)
        node_map[node_key] = {
            "id": node_key,
            "topic": node.title or "无标题",
            "data": {"des": node.desc or ""},
            "children": [],
            # Both key spellings are emitted; presumably different frontend
            # callers read different ones (TODO confirm before removing one).
            "mindmapId": mindmap.id,
            "mindmap_id": mindmap.id,
        }

    # Attach each node to its parent; nodes without a known parent become
    # top-level candidates.
    top_level = []
    for node in nodes:
        entry = node_map[str(node.id)]
        parent = node_map.get(str(node.parent_id)) if node.parent_id else None
        if parent is not None:
            parent["children"].append(entry)
        else:
            top_level.append((node, entry))

    # Drop placeholder roots (default title and no children). This runs after
    # the whole tree is built and inspects the children actually attached,
    # not the stored ``children_count``: a stale count of 0 would otherwise
    # hide a root together with its entire subtree.
    root_nodes = [
        entry
        for node, entry in top_level
        if not (node.title == "根节点标题" and not entry["children"])
    ]

    if len(root_nodes) == 1:
        return {"nodeData": root_nodes[0]}
    if root_nodes:
        # Several top-level nodes: wrap them in a synthetic root.
        return {
            "nodeData": {
                "id": "root",
                "topic": mindmap.file_name or "思维导图",
                "children": root_nodes,
            }
        }
    return _default_root()
# @api_view(['POST'])
# @transaction.atomic
# def create_mindmap(request):
# title = (request.data or {}).get('title') or '思维导图'
# m = mindMap.objects.create(file_name=title)
# root = Node.objects.create(
# mindmap=m,
# title='根节点标题',
# desc='',
# is_root=True,
# parent_id=None,
# children_count=0,
# depth=0,
# deleted=False,
# )
# # 使用新的MindElixir格式
# mindelixir_data = convert_to_mindelixir_format(m, [root])
# return Response(mindelixir_data, status=201)
# @api_view(['POST'])
# @transaction.atomic
# def create_mindmap(request):
# title = (request.data or {}).get('title') or '思维导图'
# m = mindMap.objects.create(file_name=title)
# root = Node.objects.create(
# mindmap=m,
# title='根节点标题',
# desc='',
# is_root=True,
# parent_id=None,
# children_count=0,
# depth=0,
# deleted=False,
# )
# # 使用新的MindElixir格式
# mindelixir_data = convert_to_mindelixir_format(m, [root])
# # 在返回数据中添加ID
# if isinstance(mindelixir_data, dict):
# mindelixir_data['id'] = m.id
# mindelixir_data['title'] = m.file_name
# else:
# mindelixir_data = {
# 'id': m.id,
# 'title': m.file_name,
# 'nodeData': mindelixir_data
# }
# return Response(mindelixir_data, status=201)
@api_view(['POST'])
@transaction.atomic
def create_mindmap(request):
    """Create a mind map, optionally seeded with a MindElixir-style node tree.

    Body fields:
        title: Optional file name; defaults to '思维导图'.
        data: Optional root-node dict with ``topic``, ``des`` and a nested
            ``children`` list. When absent or not a dict, a placeholder root
            titled '根节点标题' is created instead.

    Returns:
        201 with ``{'id', 'title', 'nodeData'}``. ``nodeData`` is built from
        every node just created, so seeded children come back with their
        real database IDs.
    """
    payload = request.data or {}
    title = payload.get('title') or '思维导图'
    m = mindMap.objects.create(file_name=title)

    tree = payload.get('data')
    if tree and isinstance(tree, dict):
        root = Node.objects.create(
            mindmap=m,
            title=tree.get('topic', '根节点'),
            desc=tree.get('des', ''),
            is_root=True,
            parent_id=None,
            children_count=0,
            depth=0,
            deleted=False,
        )
        if tree.get('children'):
            create_nodes_recursively(tree['children'], m, root.id)
        # Recount the root's live children so the stored value is accurate.
        root.children_count = Node.objects.filter(parent_id=root.id, deleted=False).count()
        root.save()
    else:
        # No seed data: create an empty placeholder root.
        Node.objects.create(
            mindmap=m,
            title='根节点标题',
            desc='',
            is_root=True,
            parent_id=None,
            children_count=0,
            depth=0,
            deleted=False,
        )

    # Use every node of the new map (same query as get_mindmap), not only the
    # root, so the response contains the complete tree.
    created = list(Node.objects.filter(mindmap=m, deleted=False).order_by('created_at'))
    mindelixir_data = convert_to_mindelixir_format(m, created)
    response_data = {
        'id': m.id,
        'title': m.file_name,
        'nodeData': mindelixir_data.get('nodeData', mindelixir_data),
    }
    return Response(response_data, status=201)
def create_nodes_recursively(nodes_data, mindmap, parent_id, depth=1):
    """Persist a list of MindElixir-style child dicts under ``parent_id``.

    Args:
        nodes_data: Iterable of dicts with optional ``topic``, ``des`` and
            ``children`` (a nested list of the same shape).
        mindmap: The mind map that owns the new nodes.
        parent_id: ID of the already-saved parent node.
        depth: Depth stored on the nodes of this level. Defaults to 1 (direct
            children of a depth-0 root) and grows by one per nesting level.

    After this level is created, the parent's ``children_count`` is
    recomputed from its live children in the database.
    """
    for node_data in nodes_data:
        # ``or []`` also covers an explicit ``"children": null``.
        children = node_data.get('children') or []
        node = Node.objects.create(
            mindmap=mindmap,
            title=node_data.get('topic', '节点'),
            desc=node_data.get('des', ''),
            is_root=False,
            parent_id=parent_id,
            children_count=len(children),
            depth=depth,
            deleted=False,
        )
        if children:
            create_nodes_recursively(children, mindmap, node.id, depth + 1)

    if parent_id:
        try:
            parent = Node.objects.get(id=parent_id, deleted=False)
        except Node.DoesNotExist:
            return  # the parent is gone; nothing to recount
        parent.children_count = Node.objects.filter(parent_id=parent_id, deleted=False).count()
        parent.save()
@api_view(['POST'])
@transaction.atomic
def add_nodes(request):
    """Create one or more nodes in an existing mind map.

    Body fields:
        mindMapId: ID of a non-deleted mind map (must parse as an int).
        nodes: A node dict or a list of node dicts with ``title``, ``des``,
            ``parentId`` and ``isRoot``.

    The batch is all-or-nothing: when any entry is rejected (400/404), the
    transaction is marked for rollback so nodes created earlier in the same
    request are not committed.
    """
    data = request.data or {}
    mindmap_id = data.get('mindMapId')
    nodes_payload = data.get('nodes', [])
    if not mindmap_id:
        return Response({'detail': 'mindMapId is required'}, status=400)
    try:
        # mindMapId must be an integer.
        mindmap_id = int(mindmap_id)
        m = mindMap.objects.get(id=mindmap_id, deleted=False)
    except (ValueError, TypeError):
        return Response({'detail': 'mindMapId must be a valid integer'}, status=400)
    except mindMap.DoesNotExist:
        return Response({'detail': 'mindMap not found'}, status=404)

    # Accept either a single node object or an array of them.
    if isinstance(nodes_payload, dict):
        nodes_payload = [nodes_payload]
    elif not isinstance(nodes_payload, list):
        return Response({'detail': 'nodes must be an object or array'}, status=400)

    created_nodes = []
    for n in nodes_payload:
        if not isinstance(n, dict):
            # Returning a Response does not raise, so without this the atomic
            # block would commit the nodes already created in this loop.
            transaction.set_rollback(True)
            return Response({'detail': 'each node must be an object'}, status=400)

        is_root = bool(n.get('isRoot', False))
        parent_id = n.get('parentId')
        parent_node = None
        depth = 0
        if parent_id and not is_root:
            try:
                parent_node = Node.objects.get(id=parent_id, deleted=False, mindmap=m)
            except Node.DoesNotExist:
                transaction.set_rollback(True)
                return Response({'detail': f'parent node {parent_id} not found'}, status=404)
            depth = parent_node.depth + 1

        node = Node.objects.create(
            mindmap=m,
            title=n.get('title') or '',
            desc=n.get('des') or '',
            is_root=is_root,
            parent_id=parent_id,
            children_count=0,  # a new node starts without children
            depth=depth,
            deleted=False,
        )
        # Reuse the parent fetched above instead of querying it again.
        if parent_node is not None:
            parent_node.children_count += 1
            parent_node.save()
        created_nodes.append(node)

    resp_nodes = [map_node_to_doc(x) for x in created_nodes]
    return Response({
        'success': True,
        'message': f'成功创建 {len(created_nodes)} 个节点',
        'data': {
            'mindMapId': str(m.id),
            'nodes': resp_nodes
        }
    })
@api_view(['PATCH'])
@transaction.atomic
def update_node(request):
    """Update a node's title, description and/or parent.

    Body fields:
        id: ID of the non-deleted node to update (required).
        newTitle: New title (a falsy value is stored as ``''``).
        newDes: New description (a falsy value is stored as ``''``).
        newParentId: When present, move the node under this non-deleted node
            of the same mind map; a falsy value detaches it (parent ``None``,
            depth 0).

    Moving a node under itself or one of its descendants is rejected with
    400. Parent ``children_count`` values change only when the parent really
    changes, and the depths of the node's descendants shift with the node.
    """
    body = request.data or {}
    node_id = body.get('id')
    if not node_id:
        return Response({'detail': 'id is required'}, status=400)
    try:
        node = Node.objects.get(id=node_id, deleted=False)
    except Node.DoesNotExist:
        return Response({'detail': 'node not found'}, status=404)

    # Names of the fields touched by this request, echoed back to the client.
    updated_fields = []
    if 'newTitle' in body:
        node.title = body.get('newTitle') or ''
        updated_fields.append('title')
    if 'newDes' in body:
        node.desc = body.get('newDes') or ''
        updated_fields.append('des')

    if 'newParentId' in body:
        new_parent_id = body.get('newParentId')
        old_parent_id = node.parent_id
        # Compare as strings: the request may send "5" for a stored 5.
        parent_changed = (
            (str(new_parent_id) if new_parent_id else None)
            != (str(old_parent_id) if old_parent_id else None)
        )

        parent_node = None
        new_depth = 0
        if new_parent_id:
            try:
                parent_node = Node.objects.get(
                    id=new_parent_id, deleted=False, mindmap=node.mindmap
                )
            except Node.DoesNotExist:
                return Response({'detail': 'parent node not found'}, status=404)
            new_depth = parent_node.depth + 1

        # The node itself plus all of its live descendants (string IDs).
        subtree_ids = _collect_subtree_ids(node.mindmap_id, [str(node.id)])
        if parent_node is not None and str(parent_node.id) in subtree_ids:
            # Would create a cycle and detach the whole subtree from the map.
            return Response(
                {'detail': 'cannot move a node under itself or its descendants'},
                status=400,
            )

        if parent_changed:
            if parent_node is not None:
                parent_node.children_count += 1
                parent_node.save()
            if old_parent_id:
                try:
                    old_parent = Node.objects.get(
                        id=old_parent_id, deleted=False, mindmap=node.mindmap
                    )
                except Node.DoesNotExist:
                    old_parent = None  # the old parent may already be deleted
                if old_parent is not None:
                    old_parent.children_count = max(0, old_parent.children_count - 1)
                    old_parent.save()

        # Keep descendant depths relative to the node's new depth.
        depth_delta = new_depth - node.depth
        descendant_ids = [pk for pk in subtree_ids if pk != str(node.id)]
        if depth_delta and descendant_ids:
            from django.db.models import F
            Node.objects.filter(id__in=descendant_ids).update(depth=F('depth') + depth_delta)

        node.parent_id = new_parent_id if new_parent_id else None
        node.depth = new_depth
        updated_fields.extend(['parentId', 'depth'])

    # Only touch the timestamp (and the row) when something was updated.
    if updated_fields:
        node.updated_at = timezone.now()
        node.save()
    return Response({
        'success': True,
        'message': '节点更新成功',
        'data': map_node_to_doc(node),
        'updatedFields': updated_fields
    })
def _collect_subtree_ids(mindmap_id: int, start_ids: list[str]) -> set[str]:
    """Return ``start_ids`` together with the IDs of all their live descendants.

    The tree is walked breadth-first, issuing one query per level and only
    considering non-deleted nodes of ``mindmap_id``. Descendant IDs are
    returned as strings.
    """
    seen = set(start_ids)
    level = list(start_ids)
    while level:
        child_rows = Node.objects.filter(
            mindmap_id=mindmap_id,
            parent_id__in=level,
            deleted=False,
        ).values_list('id', flat=True)
        # Only IDs not visited yet form the next level; an empty level ends
        # the walk.
        level = [child_id for child_id in map(str, child_rows) if child_id not in seen]
        seen.update(level)
    return seen
@api_view(['DELETE'])
@transaction.atomic
def delete_nodes(request):
    """Soft-delete nodes together with their whole subtrees.

    Body fields:
        nodeIds: A list of node IDs (a single ID is also accepted).

    The mind map of the first matching non-deleted node scopes the whole
    operation. Requested IDs that are unknown, already deleted or belong to
    another mind map are ignored. Surviving parents have ``children_count``
    reduced by the number of their direct children that were removed.
    """
    from collections import Counter

    body = request.data or {}
    node_ids = body.get('nodeIds', [])
    if not node_ids:
        return Response({'detail': 'nodeIds is required'}, status=400)
    # A bare ID would otherwise be iterated character by character below.
    if isinstance(node_ids, (str, int)):
        node_ids = [node_ids]

    first = Node.objects.filter(id__in=node_ids, deleted=False).first()
    if not first:
        return Response({'success': True, 'message': '无可删除节点', 'data': {'deletedCount': 0, 'deletedNodeIds': []}})
    mindmap_id = first.mindmap_id

    # Start only from requested nodes that are live and in this mind map.
    start_ids = [
        str(pk)
        for pk in Node.objects.filter(
            id__in=node_ids, mindmap_id=mindmap_id, deleted=False
        ).values_list('id', flat=True)
    ]
    all_ids = _collect_subtree_ids(mindmap_id, start_ids)

    # Count removed direct children per surviving parent; parents inside the
    # deleted set are going away too, so they need no update.
    nodes_to_delete = Node.objects.filter(id__in=list(all_ids), mindmap_id=mindmap_id, deleted=False)
    reductions = Counter(
        n.parent_id
        for n in nodes_to_delete
        if n.parent_id and str(n.parent_id) not in all_ids
    )
    for parent_id, count_reduction in reductions.items():
        try:
            parent = Node.objects.get(id=parent_id, deleted=False)
        except Node.DoesNotExist:
            continue  # the parent may already be deleted
        parent.children_count = max(0, parent.children_count - count_reduction)
        parent.save()

    # Soft delete; restricted to live nodes of this mind map.
    updated = Node.objects.filter(
        id__in=list(all_ids), mindmap_id=mindmap_id, deleted=False
    ).update(deleted=True, updated_at=timezone.now())
    return Response({
        'success': True,
        'message': '节点删除成功',
        'data': {
            'deletedCount': int(updated),
            'deletedNodeIds': list(all_ids)
        }
    })
@api_view(['POST'])
def generate_markdown(request):
    """Generate Markdown through the AI service in a single (non-streaming) call.

    Body fields: ``user_prompt`` (required), plus optional ``system_prompt``,
    ``model``, ``base_url`` and ``api_key`` with the defaults shown below.
    Responds 400 when the user prompt is empty, and 500 when the AI call
    returns a falsy result or any exception is raised.
    """
    try:
        payload = request.data
        system_prompt = payload.get('system_prompt', '')
        user_prompt = payload.get('user_prompt', '')
        model = payload.get('model', 'glm-4.5')
        base_url = payload.get('base_url', 'https://open.bigmodel.cn/api/paas/v4/')
        api_key = payload.get('api_key', '')

        if not user_prompt:
            return Response({'error': '用户提示词不能为空'}, status=400)

        # Imported lazily, inside the try, so import failures become a 500.
        from .ai_service import call_ai_api

        generated = call_ai_api(system_prompt, user_prompt, model, base_url, api_key)
        if not generated:
            return Response({
                'error': 'AI API调用失败',
                'success': False
            }, status=500)
        return Response({
            'markdown': generated,
            'success': True
        })
    except Exception as err:
        return Response({
            'error': str(err),
            'success': False
        }, status=500)
2025-09-08 10:20:48 +00:00
@api_view(['POST', 'OPTIONS'])
def generate_ai_content_stream(request):
    """Stream AI-generated content to the client as Server-Sent Events.

    Every event is ``data: <json>`` followed by a blank line. The JSON object
    has ``type`` (``start``, ``chunk``, ``end`` or ``error``) and
    ``content``. OPTIONS requests are answered as a CORS preflight.

    Body fields: ``user_prompt`` (required), plus optional ``system_prompt``,
    ``model``, ``base_url`` and ``api_key``.

    Diagnostics go through ``logging`` (per-chunk detail at DEBUG level)
    instead of ``print``/``traceback.print_exc``.
    """
    import logging

    logger = logging.getLogger(__name__)
    cors_headers = (
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Methods', 'GET, POST, PUT, DELETE, OPTIONS'),
        ('Access-Control-Allow-Headers', 'Content-Type, Authorization, Cache-Control, X-Requested-With'),
        ('Access-Control-Allow-Credentials', 'true'),
        ('Access-Control-Max-Age', '86400'),
    )

    def apply_cors(response):
        for header, value in cors_headers:
            response[header] = value
        return response

    def sse_event(event_type, content):
        return f"data: {json.dumps({'type': event_type, 'content': content})}\n\n"

    # CORS preflight request.
    if request.method == 'OPTIONS':
        return apply_cors(Response())

    try:
        data = request.data
        user_prompt = data.get('user_prompt', '')
        system_prompt = data.get('system_prompt', '你是一个专业的思维导图内容生成助手请根据用户的需求生成结构化的Markdown内容。')
        model = data.get('model', 'glm-4.5')
        base_url = data.get('base_url', 'https://open.bigmodel.cn/api/paas/v4/')
        api_key = data.get('api_key', '')
        if not user_prompt:
            return Response({'error': '用户提示词不能为空'}, status=400)

        from .ai_service import call_ai_api

        def generate_stream():
            try:
                logger.info("开始调用流式AI API...")
                stream = call_ai_api(system_prompt, user_prompt, model, base_url, api_key, stream=True)
                if stream is None:
                    logger.warning("AI API返回None发送错误信号")
                    yield sse_event('error', 'AI API调用失败')
                    return
                logger.info("开始发送流式数据...")
                yield sse_event('start', '')
                chunk_count = 0
                for chunk in stream:
                    if chunk:
                        chunk_count += 1
                        # Per-chunk detail only at DEBUG to keep logs quiet.
                        logger.debug("发送第%d个数据块: %s...", chunk_count, chunk[:50])
                        yield sse_event('chunk', chunk)
                logger.info("流式数据发送完成,总共%d个数据块", chunk_count)
                yield sse_event('end', '')
            except Exception as e:
                logger.exception("流式生成过程中发生错误: %s", e)
                # Report the failure to the client inside the stream.
                yield sse_event('error', str(e))

        response = StreamingHttpResponse(
            generate_stream(),
            content_type='text/event-stream'
        )
        response['Cache-Control'] = 'no-cache'
        return apply_cors(response)
    except Exception as e:
        logger.exception("流式API处理过程中发生错误: %s", e)
        return Response({
            'error': str(e),
            'success': False
        }, status=500)
@api_view(['POST', 'OPTIONS'])
def test_stream(request):
    """Stream a fixed Markdown sample as Server-Sent Events for client testing.

    Emits a ``start`` event, then one ``chunk`` event per character of the
    sample (sleeping 10 ms after each), then an ``end`` event; an exception
    while generating becomes an ``error`` event. OPTIONS requests are
    answered as a CORS preflight.
    """
    import time

    cors_headers = (
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Methods', 'GET, POST, PUT, DELETE, OPTIONS'),
        ('Access-Control-Allow-Headers', 'Content-Type, Authorization, Cache-Control, X-Requested-With'),
        ('Access-Control-Allow-Credentials', 'true'),
        ('Access-Control-Max-Age', '86400'),
    )

    def apply_cors(response):
        for header, value in cors_headers:
            response[header] = value
        return response

    def sse_event(event_type, content):
        return f"data: {json.dumps({'type': event_type, 'content': content})}\n\n"

    # CORS preflight request.
    if request.method == 'OPTIONS':
        return apply_cors(Response())

    def generate_test_stream():
        try:
            yield sse_event('start', '')
            test_content = "# 测试思维导图\n\n## 主要主题\n- 主题1\n- 主题2\n\n## 详细内容\n- 内容1\n- 内容2"
            for char in test_content:
                yield sse_event('chunk', char)
                # Small delay to imitate incremental model output.
                time.sleep(0.01)
            yield sse_event('end', '')
        except Exception as e:
            yield sse_event('error', str(e))

    response = StreamingHttpResponse(
        generate_test_stream(),
        content_type='text/event-stream'
    )
    response['Cache-Control'] = 'no-cache'
    return apply_cors(response)