|
@@ -0,0 +1,1669 @@
|
|
|
+import logging
|
|
|
+import platform
|
|
|
+import shutil
|
|
|
+import subprocess
|
|
|
+import tempfile
|
|
|
+import threading
|
|
|
+import uuid
|
|
|
+import warnings
|
|
|
+from datetime import timedelta, datetime, timezone
|
|
|
+from io import BytesIO
|
|
|
+from time import sleep
|
|
|
+from typing import Collection
|
|
|
+
|
|
|
+import PyPDF2
|
|
|
+import chardet
|
|
|
+import cv2
|
|
|
+import fitz
|
|
|
+import pandas as pd
|
|
|
+import redis
|
|
|
+import requests
|
|
|
+from PIL import Image
|
|
|
+from bs4 import BeautifulSoup
|
|
|
+from django.core.exceptions import ObjectDoesNotExist
|
|
|
+from django.core.paginator import Paginator, PageNotAnInteger, EmptyPage
|
|
|
+from django.db import transaction
|
|
|
+from django.shortcuts import get_object_or_404
|
|
|
+from minio import Minio
|
|
|
+import re
|
|
|
+import os
|
|
|
+
|
|
|
+import markdown
|
|
|
+from paddleocr import PaddleOCR
|
|
|
+from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
|
|
|
+from requests import RequestException
|
|
|
+from scipy import interpolate
|
|
|
+from tabulate import tabulate
|
|
|
+from DCbackend.settings import MILVUS_HOST, MILVUS_PORT, VECTOR_DIMENSION, MILVUS_USER, MILVUS_PASSWORD
|
|
|
+from base import logger
|
|
|
+
|
|
|
+os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
|
|
|
+from scipy.spatial.distance import cosine
|
|
|
+from django.utils import timezone
|
|
|
+import traceback
|
|
|
+from DCbackend.utils.common import success, fail
|
|
|
+from backend.Service.MinioService import MinioService
|
|
|
+from backend.models import Knowledgebase, DocumentKbm, File2document, File, Task, TaskSublist, KbmDocumentType
|
|
|
+import pytesseract
|
|
|
+import numpy as np
|
|
|
+import json
|
|
|
+import time
|
|
|
+import pika
|
|
|
+from django.db.models import Count, Case, When, IntegerField, Q, Max
|
|
|
+from django.conf import settings
|
|
|
+
|
|
|
+
|
|
|
+class DocumentQueue:
|
|
|
+ QUEUE_KEY = "document_process_queue"
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+minio_client = Minio(
|
|
|
+ settings.MINIO_ENDPOINT,
|
|
|
+ access_key=settings.MINIO_ACCESS_KEY,
|
|
|
+ secret_key=settings.MINIO_SECRET_KEY,
|
|
|
+ secure=settings.MINIO_SECURE
|
|
|
+)
|
|
|
+
|
|
|
+if os.name == 'nt': # Windows
|
|
|
+ pytesseract.pytesseract.tesseract_cmd = r'D:\Program Files\OCR\tesseract.exe'
|
|
|
+else: # macOS 或 Linux
|
|
|
+ pytesseract.pytesseract.tesseract_cmd = r'/usr/bin/tesseract'
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+class KbmService:
|
|
|
+ bert_model = None
|
|
|
+ bert_tokenizer = None
|
|
|
+ #大模型地址
|
|
|
+ API_URL = "http://127.0.0.1:11434/api/embeddings"
|
|
|
+
|
|
|
+ #通用rabbitmq 放入队列
|
|
|
+ def send_to_rabbitmq(queue_name, message):
|
|
|
+ """
|
|
|
+ 将消息发送到指定的RabbitMQ队列
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ connection = pika.BlockingConnection(pika.ConnectionParameters(
|
|
|
+ host=settings.RABBITMQ_HOST,
|
|
|
+ port=settings.RABBITMQ_PORT,
|
|
|
+ credentials=pika.PlainCredentials(
|
|
|
+ settings.RABBITMQ_USER,
|
|
|
+ settings.RABBITMQ_PASSWORD
|
|
|
+ )
|
|
|
+ ))
|
|
|
+ channel = connection.channel()
|
|
|
+
|
|
|
+ channel.queue_declare(queue=queue_name, durable=True)
|
|
|
+
|
|
|
+ channel.basic_publish(
|
|
|
+ exchange='',
|
|
|
+ routing_key=queue_name,
|
|
|
+ body=json.dumps(message),
|
|
|
+ properties=pika.BasicProperties(
|
|
|
+ delivery_mode=2, # 使消息持久化
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ connection.close()
|
|
|
+ logger.info(f"消息已发送到队列 {queue_name}")
|
|
|
+ return True
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"发送消息到RabbitMQ时出错: {str(e)}")
|
|
|
+ return False
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def selectBucketInfo(request):
|
|
|
+ # user_id = request.POST.get("user_id")
|
|
|
+ knowledgebases = Knowledgebase.objects.filter().exclude(status=4).order_by('-create_time').values('id', 'create_time', 'name', 'doc_num', 'description')
|
|
|
+
|
|
|
+ result = []
|
|
|
+ for kb in knowledgebases:
|
|
|
+ # 使用一次查询获取所有计数
|
|
|
+ counts = DocumentKbm.objects.filter(kb_id=kb['id']).aggregate(
|
|
|
+ word_count=Count(Case(When(type__in=['doc', 'docx'], then=1), output_field=IntegerField())),
|
|
|
+ pdf_count=Count(Case(When(type='pdf', then=1), output_field=IntegerField())),
|
|
|
+ excel_count=Count(Case(When(type__in=['xls', 'xlsx'], then=1), output_field=IntegerField()))
|
|
|
+ )
|
|
|
+
|
|
|
+ kb_data = kb.copy() # 创建 kb 的副本,以避免修改原始数据
|
|
|
+ kb_data.update(counts)
|
|
|
+ result.append(kb_data)
|
|
|
+
|
|
|
+ return success(result)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def getFileInfo(request):
|
|
|
+ try:
|
|
|
+ bucket_id = request.POST.get("bucket_id")
|
|
|
+ page = request.POST.get("page", 1)
|
|
|
+ per_page = request.POST.get("pageSize", 10)
|
|
|
+ object_name = request.POST.get("object_name", "")
|
|
|
+ run = request.POST.get("run", "")
|
|
|
+ type = request.POST.get("type", "")
|
|
|
+ doc_type_id = request.POST.get("doc_type_id")
|
|
|
+ if not bucket_id:
|
|
|
+ return fail("bucket_id为空")
|
|
|
+
|
|
|
+
|
|
|
+ # 确保 page 和 per_page 是整数
|
|
|
+ page = int(page)
|
|
|
+ per_page = int(per_page)
|
|
|
+
|
|
|
+ # # 查询文档并排序
|
|
|
+ # documents = DocumentKbm.objects.filter(
|
|
|
+ # Q(kb_id=bucket_id) &
|
|
|
+ # Q(name__icontains=object_name)&
|
|
|
+ # Q(run__icontains=run)&
|
|
|
+ # Q(type__icontains=type)&
|
|
|
+ # Q(doc_type_id=doc_type_id)&
|
|
|
+ # ~Q(status=4)
|
|
|
+ # ).order_by('-create_time')
|
|
|
+
|
|
|
+ # 构建基本查询
|
|
|
+ query = Q(kb_id=bucket_id) & Q(name__icontains=object_name) & Q(run__icontains=run) & Q(type__icontains=type) & ~Q(status=4)
|
|
|
+
|
|
|
+ # 如果 doc_type_id 有值,则添加到查询条件
|
|
|
+ if doc_type_id:
|
|
|
+ query &= Q(doc_type_id=doc_type_id)
|
|
|
+
|
|
|
+ # 查询文档并排序
|
|
|
+ documents = DocumentKbm.objects.filter(query).order_by('-create_time')
|
|
|
+
|
|
|
+ # 创建分页器
|
|
|
+ paginator = Paginator(documents, per_page)
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 获取指定页的结果
|
|
|
+ documents_page = paginator.page(page)
|
|
|
+ except PageNotAnInteger:
|
|
|
+ # 如果页码不是整数,返回第一页
|
|
|
+ documents_page = paginator.page(1)
|
|
|
+ except EmptyPage:
|
|
|
+ # 如果页码超出范围,返回最后一页
|
|
|
+ documents_page = paginator.page(paginator.num_pages)
|
|
|
+
|
|
|
+ # 将查询结果转换为列表
|
|
|
+ result = list(documents_page.object_list.values())
|
|
|
+ for info in result:
|
|
|
+ document_id = info['id']
|
|
|
+ max_page = TaskSublist.objects.filter(doc_id=document_id).aggregate(Max('page_number'))['page_number__max']
|
|
|
+ info['max_page'] = max_page if max_page is not None else 0
|
|
|
+
|
|
|
+ pagination_info = {
|
|
|
+ 'total_count': paginator.count,
|
|
|
+ 'total_pages': paginator.num_pages,
|
|
|
+ 'total_size': per_page,
|
|
|
+ 'current_page': documents_page.number,
|
|
|
+ 'has_next': documents_page.has_next(),
|
|
|
+ 'has_previous': documents_page.has_previous()
|
|
|
+ }
|
|
|
+ data = {
|
|
|
+ 'pagination': pagination_info,
|
|
|
+ 'documents': result
|
|
|
+ }
|
|
|
+
|
|
|
+ return success(data)
|
|
|
+ except Exception as e:
|
|
|
+ return fail("获取信息失败")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ @transaction.atomic
|
|
|
+ def updateName(request):
|
|
|
+ try:
|
|
|
+ new_name = request.POST.get("new_name")
|
|
|
+ document_id = request.POST.get("document_id")
|
|
|
+
|
|
|
+ if not new_name or not document_id:
|
|
|
+ return fail("新名称和文件ID不能为空")
|
|
|
+
|
|
|
+ # 获取 DocumentKbm 实例并更新
|
|
|
+ document = get_object_or_404(DocumentKbm, id=document_id)
|
|
|
+ location = document.location
|
|
|
+
|
|
|
+ # 获取原始文件的扩展名
|
|
|
+ _, original_extension = os.path.splitext(location)
|
|
|
+
|
|
|
+ # 检查新名称是否包含扩展名,如果没有则添加原始扩展名
|
|
|
+ _, new_extension = os.path.splitext(new_name)
|
|
|
+ if not new_extension:
|
|
|
+ new_name = f"{new_name}{original_extension}"
|
|
|
+
|
|
|
+ document.name = new_name
|
|
|
+ document.save()
|
|
|
+
|
|
|
+ # 获取关联的 File2document 和 File
|
|
|
+ file2doc = File2document.objects.filter(document_id=document_id).first()
|
|
|
+ if file2doc:
|
|
|
+ file = get_object_or_404(File, id=file2doc.file_id)
|
|
|
+ file.name = new_name
|
|
|
+ file.save()
|
|
|
+ else:
|
|
|
+ # 记录一个警告,因为没有找到关联的 File
|
|
|
+ logger.error(f"Warning: No associated File found for DocumentKbm with id {document_id}")
|
|
|
+
|
|
|
+ return success("文件名更新成功")
|
|
|
+ except ObjectDoesNotExist:
|
|
|
+ return fail("指定的文件或关联文件不存在")
|
|
|
+ except Exception as e:
|
|
|
+ return fail(f"更新文件名失败: {str(e)}")
|
|
|
+
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ @transaction.atomic
|
|
|
+ def deleteDocument(request):
|
|
|
+
|
|
|
+ document_id = request.POST.get("document_id")
|
|
|
+
|
|
|
+ document = get_object_or_404(DocumentKbm, id=document_id)
|
|
|
+ document.status = 4
|
|
|
+ document.save()
|
|
|
+
|
|
|
+ file2doc = File2document.objects.filter(document_id=document_id).first()
|
|
|
+ if file2doc:
|
|
|
+ file = get_object_or_404(File, id=file2doc.file_id)
|
|
|
+ file.status = 4
|
|
|
+ file.save()
|
|
|
+
|
|
|
+ kb_id = document.kb_id
|
|
|
+ new_count = DocumentKbm.objects.filter(kb_id=kb_id).exclude(status=4).count()
|
|
|
+ Knowledgebase.objects.filter(id=kb_id).update(doc_num=new_count)
|
|
|
+ try:
|
|
|
+ # 清理milvus
|
|
|
+ # 连接到 Milvus
|
|
|
+ connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
|
|
|
+
|
|
|
+ kmb = Knowledgebase.objects.filter(id=kb_id).first()
|
|
|
+ collection = Collection(kmb.location)
|
|
|
+
|
|
|
+ tasks = TaskSublist.objects.filter(doc_id=document.id)
|
|
|
+ for task in tasks:
|
|
|
+ expr = f'id in [{task.milvus_id}]'
|
|
|
+ collection.delete(expr)
|
|
|
+ return success("删除成功")
|
|
|
+ except Exception as e:
|
|
|
+ return fail(f"删除milvus集合时发生错误: {str(e)}")
|
|
|
+
|
|
|
+ finally:
|
|
|
+ # 断开 Milvus 连接
|
|
|
+ connections.disconnect("default")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def getUrl(request):
|
|
|
+ try:
|
|
|
+ document_id = request.POST.get("document_id")
|
|
|
+ if not document_id:
|
|
|
+ return fail("文档ID不能为空")
|
|
|
+
|
|
|
+ # 获取 DocumentKbm 对象
|
|
|
+ document = get_object_or_404(DocumentKbm, id=document_id)
|
|
|
+
|
|
|
+ object_name = document.location
|
|
|
+
|
|
|
+ # 获取对应的 Knowledgebase 对象
|
|
|
+ knowledgebase = get_object_or_404(Knowledgebase, id=document.kb_id)
|
|
|
+ bucket_name = knowledgebase.location
|
|
|
+
|
|
|
+ return MinioService.geturl(object_name, bucket_name)
|
|
|
+
|
|
|
+ except ObjectDoesNotExist:
|
|
|
+ return fail("指定的文档或知识库不存在")
|
|
|
+ except Exception as e:
|
|
|
+ return fail(f"获取URL失败: {str(e)}")
|
|
|
+
|
|
|
+
|
|
|
+ #新rabbitmq队列
|
|
|
+ @staticmethod
|
|
|
+ def analysis(request):
|
|
|
+ document_id = request.POST.get("document_id")
|
|
|
+ start_page = int(request.POST.get('start_page', 1))
|
|
|
+ end_page = int(request.POST.get('end_page', -1))
|
|
|
+ max_tokens = int(request.POST.get('max_tokens', 2048))
|
|
|
+ if max_tokens == 0:
|
|
|
+ max_tokens = 2048
|
|
|
+
|
|
|
+ logger.info(f"开始处理文档 ID: {document_id}")
|
|
|
+
|
|
|
+ try:
|
|
|
+ document = DocumentKbm.objects.get(id=document_id)
|
|
|
+ if int(document.run) in [1, 5]: # 1: 处理中, 5: 等待处理
|
|
|
+ logger.info(f"文档 {document_id} 已有队列")
|
|
|
+ return success("文档正在处理中或已经处理完成")
|
|
|
+
|
|
|
+ # 准备消息
|
|
|
+ message = {
|
|
|
+ 'document_id': document_id,
|
|
|
+ 'start_page': start_page,
|
|
|
+ 'end_page': end_page,
|
|
|
+ 'max_tokens': max_tokens
|
|
|
+ }
|
|
|
+
|
|
|
+ # 发送消息到队列
|
|
|
+ if KbmService.send_to_rabbitmq(settings.RABBITMQ_QUEUE_NAME, message):
|
|
|
+ # 更新文档状态为等待处理
|
|
|
+ document.run = 5 # 5表示等待处理
|
|
|
+ document.save()
|
|
|
+ logger.info(f"文档 {document_id} 状态已更新为等待处理")
|
|
|
+ return success("文档已添加到处理队列")
|
|
|
+ else:
|
|
|
+ document.run = 4
|
|
|
+ document.save()
|
|
|
+ return fail("添加文档到处理队列失败")
|
|
|
+
|
|
|
+ except DocumentKbm.DoesNotExist:
|
|
|
+ logger.error(f"文档 {document_id} 不存在")
|
|
|
+ document.run = 4
|
|
|
+ document.save()
|
|
|
+ return fail("文档不存在")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"处理文档 {document_id} 时出错: {str(e)}")
|
|
|
+ document.run = 4
|
|
|
+ document.save()
|
|
|
+ return fail("处理文档时出错")
|
|
|
+
|
|
|
+ semaphore = threading.Semaphore(4)
|
|
|
+ # @staticmethod
|
|
|
+ # def process_queue():
|
|
|
+ # logger.info("开始监测RabbitMQ队列")
|
|
|
+ # connection = pika.BlockingConnection(pika.ConnectionParameters(
|
|
|
+ # host=settings.RABBITMQ_HOST,
|
|
|
+ # port=settings.RABBITMQ_PORT,
|
|
|
+ # credentials=pika.PlainCredentials(
|
|
|
+ # settings.RABBITMQ_USER,
|
|
|
+ # settings.RABBITMQ_PASSWORD
|
|
|
+ # )
|
|
|
+ # ))
|
|
|
+ # channel = connection.channel()
|
|
|
+ # channel.queue_declare(queue=settings.RABBITMQ_QUEUE_NAME, durable=True)
|
|
|
+ #
|
|
|
+ # def callback(ch, method, properties, body):
|
|
|
+ # with KbmService.semaphore:
|
|
|
+ # try:
|
|
|
+ # job = json.loads(body)
|
|
|
+ # document_id = job['document_id']
|
|
|
+ # start_page = job['start_page']
|
|
|
+ # end_page = job['end_page']
|
|
|
+ # max_tokens = job['max_tokens']
|
|
|
+ #
|
|
|
+ # logger.info(f"开始执行解析文档 {document_id}")
|
|
|
+ # KbmService.async_analysis(document_id, start_page, end_page, max_tokens)
|
|
|
+ #
|
|
|
+ # # 处理成功,确认消息
|
|
|
+ # ch.basic_ack(delivery_tag=method.delivery_tag)
|
|
|
+ # except Exception as e:
|
|
|
+ # logger.error(f"处理队列消息时发生错误: {str(e)}")
|
|
|
+ # # 处理失败,拒绝消息并重新入队
|
|
|
+ # ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
|
|
|
+ #
|
|
|
+ # # 设置预取计数为4,与最大并发数相匹配
|
|
|
+ # channel.basic_qos(prefetch_count=4)
|
|
|
+ # channel.basic_consume(queue=settings.RABBITMQ_QUEUE_NAME, on_message_callback=callback)
|
|
|
+ #
|
|
|
+ # logger.info('等待队列消息。要退出请按 CTRL+C')
|
|
|
+ # channel.start_consuming()
|
|
|
+ connection = None
|
|
|
+ channel = None
|
|
|
+ should_stop = False
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def check_and_process_queue():
|
|
|
+ KbmService.should_stop = False
|
|
|
+ while not KbmService.should_stop:
|
|
|
+ try:
|
|
|
+ if KbmService.queue_has_messages():
|
|
|
+ KbmService.process_queue()
|
|
|
+ else:
|
|
|
+ logger.info("队列为空,等待下一次检查...")
|
|
|
+ time.sleep(60) # 等待60秒后再次检查
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"检查队列时发生错误: {str(e)}")
|
|
|
+ time.sleep(60) # 发生错误时,等待60秒后重试
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def queue_has_messages():
|
|
|
+ try:
|
|
|
+ connection = KbmService.create_connection()
|
|
|
+ channel = connection.channel()
|
|
|
+ queue = channel.queue_declare(queue=settings.RABBITMQ_QUEUE_NAME, passive=True)
|
|
|
+ message_count = queue.method.message_count
|
|
|
+ connection.close()
|
|
|
+ return message_count > 0
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"检查队列消息数量时发生错误: {str(e)}")
|
|
|
+ return False
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def create_connection():
|
|
|
+ return pika.BlockingConnection(pika.ConnectionParameters(
|
|
|
+ host=settings.RABBITMQ_HOST,
|
|
|
+ port=settings.RABBITMQ_PORT,
|
|
|
+ credentials=pika.PlainCredentials(
|
|
|
+ settings.RABBITMQ_USER,
|
|
|
+ settings.RABBITMQ_PASSWORD
|
|
|
+ )
|
|
|
+ ))
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def process_queue():
|
|
|
+ logger.info("队列中有消息,开始处理...")
|
|
|
+ KbmService.connection = KbmService.create_connection()
|
|
|
+ KbmService.channel = KbmService.connection.channel()
|
|
|
+ KbmService.channel.queue_declare(queue=settings.RABBITMQ_QUEUE_NAME, durable=True)
|
|
|
+ KbmService.channel.basic_qos(prefetch_count=4)
|
|
|
+ KbmService.channel.basic_consume(queue=settings.RABBITMQ_QUEUE_NAME, on_message_callback=KbmService.callback)
|
|
|
+
|
|
|
+ try:
|
|
|
+ KbmService.channel.start_consuming()
|
|
|
+ except KeyboardInterrupt:
|
|
|
+ KbmService.should_stop = True
|
|
|
+ finally:
|
|
|
+ KbmService.close_connection()
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def callback(ch, method, properties, body):
|
|
|
+ with KbmService.semaphore:
|
|
|
+ try:
|
|
|
+ job = json.loads(body)
|
|
|
+ document_id = job['document_id']
|
|
|
+ start_page = job['start_page']
|
|
|
+ end_page = job['end_page']
|
|
|
+ max_tokens = job['max_tokens']
|
|
|
+
|
|
|
+ logger.info(f"开始执行解析文档 {document_id}")
|
|
|
+ KbmService.async_analysis(document_id, start_page, end_page, max_tokens)
|
|
|
+ ch.basic_ack(delivery_tag=method.delivery_tag)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"处理队列消息时发生错误: {str(e)}")
|
|
|
+ ch.basic_nack(delivery_tag=method.delivery_tag, requeue=True)
|
|
|
+
|
|
|
+ # 检查是否还有更多消息
|
|
|
+ if not KbmService.queue_has_messages():
|
|
|
+ logger.info("队列处理完毕,停止消费...")
|
|
|
+ ch.stop_consuming()
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def close_connection():
|
|
|
+ if KbmService.channel:
|
|
|
+ try:
|
|
|
+ KbmService.channel.close()
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+ if KbmService.connection:
|
|
|
+ try:
|
|
|
+ KbmService.connection.close()
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+ KbmService.channel = None
|
|
|
+ KbmService.connection = None
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def stop_service():
|
|
|
+ KbmService.should_stop = True
|
|
|
+ if KbmService.channel:
|
|
|
+ KbmService.channel.stop_consuming()
|
|
|
+ KbmService.close_connection()
|
|
|
+ @staticmethod
|
|
|
+ def get_embedding_excel(text, target_dim=768):
|
|
|
+ """
|
|
|
+ 获取文本的嵌入向量
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ if not text or not text.strip():
|
|
|
+ logging.warning("Empty text provided for embedding. Returning zero vector.")
|
|
|
+ return np.zeros(target_dim).tolist()
|
|
|
+
|
|
|
+ # 确保文本被正确编码
|
|
|
+ encoded_text = text.encode('utf-8').decode('utf-8')
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "model": "nomic-embed-text:latest",
|
|
|
+ "prompt": encoded_text
|
|
|
+ }
|
|
|
+ headers = {"Content-Type": "application/json"}
|
|
|
+
|
|
|
+ response = requests.post(KbmService.API_URL, json=payload, headers=headers)
|
|
|
+ logger.info(f"response::::{response}")
|
|
|
+ response.raise_for_status()
|
|
|
+ embedding_data = response.json()
|
|
|
+
|
|
|
+ if 'embedding' not in embedding_data:
|
|
|
+ raise ValueError(f"API 响应中没有找到嵌入向量. 响应内容: {embedding_data}")
|
|
|
+
|
|
|
+ embedding = embedding_data['embedding']
|
|
|
+ original_embedding = np.array(embedding)
|
|
|
+
|
|
|
+ if len(original_embedding) == target_dim:
|
|
|
+ return original_embedding.tolist()
|
|
|
+
|
|
|
+ # 如果原始维度不等于目标维度,进行插值
|
|
|
+ original_indices = np.arange(len(original_embedding))
|
|
|
+ new_indices = np.linspace(0, len(original_embedding) - 1, target_dim)
|
|
|
+ f = interpolate.interp1d(original_indices, original_embedding)
|
|
|
+ extended_embedding = f(new_indices)
|
|
|
+
|
|
|
+ return extended_embedding.tolist()
|
|
|
+
|
|
|
+ except requests.exceptions.RequestException as e:
|
|
|
+ logging.error(f"API 请求错误: {str(e)}")
|
|
|
+ raise
|
|
|
+ except ValueError as e:
|
|
|
+ logging.error(f"值错误: {str(e)}")
|
|
|
+ raise
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"获取文本嵌入时发生意外错误: {str(e)}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def get_embedding_pdf(cls, text, target_dim=768, max_retries=3, backoff_factor=0.3):
|
|
|
+ """
|
|
|
+ 获取文本的嵌入向量,并填充或截断到目标维度,包含重试机制
|
|
|
+ """
|
|
|
+ for attempt in range(max_retries):
|
|
|
+ try:
|
|
|
+ payload = {
|
|
|
+ "model": "nomic-embed-text:latest",
|
|
|
+ "prompt": text
|
|
|
+ }
|
|
|
+ headers = {
|
|
|
+ "Content-Type": "application/json"
|
|
|
+ }
|
|
|
+
|
|
|
+ response = requests.post(cls.API_URL, json=payload, headers=headers, timeout=30)
|
|
|
+ sleep(0.3)
|
|
|
+ response.raise_for_status()
|
|
|
+ result = response.json()
|
|
|
+
|
|
|
+ embedding = result.get('embedding')
|
|
|
+
|
|
|
+ if embedding is None:
|
|
|
+ raise ValueError("API 响应中没有找到嵌入向量")
|
|
|
+
|
|
|
+ embedding_array = np.array(embedding)
|
|
|
+ current_dim = embedding_array.shape[0]
|
|
|
+
|
|
|
+ if current_dim < target_dim:
|
|
|
+ padded_embedding = np.pad(embedding_array, (0, target_dim - current_dim), 'constant')
|
|
|
+ logging.info(f"向量已从 {current_dim} 维填充到 {target_dim} 维")
|
|
|
+ return padded_embedding
|
|
|
+ elif current_dim > target_dim:
|
|
|
+ truncated_embedding = embedding_array[:target_dim]
|
|
|
+ logging.info(f"向量已从 {current_dim} 维截断到 {target_dim} 维")
|
|
|
+ return truncated_embedding
|
|
|
+ else:
|
|
|
+ return embedding_array
|
|
|
+
|
|
|
+ except RequestException as e:
|
|
|
+ logging.error(f"API 请求错误 (尝试 {attempt + 1}/{max_retries}): {e}")
|
|
|
+ if attempt == max_retries - 1:
|
|
|
+ raise
|
|
|
+ time.sleep(backoff_factor * (2 ** attempt))
|
|
|
+ except ValueError as e:
|
|
|
+ logging.error(f"解析响应错误: {e}")
|
|
|
+ raise
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"获取文本嵌入时发生未知错误: {e}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ raise Exception("达到最大重试次数,无法获取嵌入")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def split_text_by_semantic(text, max_tokens, bucket_name, similarity_threshold=0.5, batch_size=1000):
|
|
|
+ logging.info("开始分割文本并保存到向量数据库")
|
|
|
+
|
|
|
+ try:
|
|
|
+ #object1 object1为后续可能添加的字段 因为无法直接修改名称 备用
|
|
|
+ source = "知识库"
|
|
|
+ object1 ="some_object1"
|
|
|
+ object2 ="some_object2"
|
|
|
+ # 连接到Milvus
|
|
|
+ connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT,user=MILVUS_USER,password=MILVUS_PASSWORD)
|
|
|
+ collection_name = f"{bucket_name}"
|
|
|
+ collection = KbmService._get_or_create_collection(collection_name)
|
|
|
+
|
|
|
+ sentences = KbmService._split_sentences(text)
|
|
|
+ if not sentences:
|
|
|
+ logging.warning("没有找到有效的句子,将文本按最大令牌数分割")
|
|
|
+ sentences = [text[i:i + max_tokens] for i in range(0, len(text), max_tokens)]
|
|
|
+
|
|
|
+ chunks = []
|
|
|
+ current_chunk = sentences[0]
|
|
|
+ current_embedding = KbmService.get_embedding_pdf(current_chunk, target_dim=VECTOR_DIMENSION)
|
|
|
+ batch_data = []
|
|
|
+
|
|
|
+ for sentence in sentences[1:]:
|
|
|
+ sentence_embedding = KbmService.get_embedding_pdf(sentence, target_dim=VECTOR_DIMENSION)
|
|
|
+ similarity = 1 - cosine(current_embedding, sentence_embedding)
|
|
|
+
|
|
|
+ if len(current_chunk) + len(sentence) <= max_tokens and similarity >= similarity_threshold:
|
|
|
+ current_chunk += sentence
|
|
|
+ current_embedding = (current_embedding + sentence_embedding) / 2
|
|
|
+ else:
|
|
|
+ batch_data.append((current_chunk, current_embedding))
|
|
|
+ if len(batch_data) >= batch_size:
|
|
|
+ ids = KbmService._insert_batch(collection, batch_data,source,object1,object2)
|
|
|
+ sleep(1)
|
|
|
+ logger.info("减少milvus压力睡眠1秒")
|
|
|
+ if ids is not None:
|
|
|
+ chunks.extend([{'content': chunk, 'milvus_id': id} for (chunk, _), id in zip(batch_data, ids)])
|
|
|
+ else:
|
|
|
+ logging.error("向 Milvus 插入批量数据失败,这批数据将被跳过")
|
|
|
+ batch_data = []
|
|
|
+ current_chunk = sentence
|
|
|
+ current_embedding = sentence_embedding
|
|
|
+
|
|
|
+ # 处理最后一个chunk和剩余的batch数据
|
|
|
+ if current_chunk:
|
|
|
+ batch_data.append((current_chunk, current_embedding))
|
|
|
+ if batch_data:
|
|
|
+ ids = KbmService._insert_batch(collection, batch_data,source,object1,object2)
|
|
|
+ sleep(1)
|
|
|
+ logger.info("减少milvus压力睡眠1秒")
|
|
|
+ if ids is not None:
|
|
|
+ chunks.extend([{'content': chunk, 'milvus_id': id} for (chunk, _), id in zip(batch_data, ids)])
|
|
|
+ else:
|
|
|
+ logging.error("向 Milvus 插入批量数据失败,这批数据将被跳过")
|
|
|
+
|
|
|
+ KbmService._create_index_and_load(collection)
|
|
|
+
|
|
|
+ logging.info(f"成功将{len(chunks)}个文本块分割并保存到Milvus")
|
|
|
+ return chunks
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"处理文本时发生错误: {str(e)}")
|
|
|
+ raise
|
|
|
+ finally:
|
|
|
+ connections.disconnect("default")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _get_or_create_collection(collection_name):
|
|
|
+ if not utility.has_collection(collection_name):
|
|
|
+ fields = [
|
|
|
+ FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
|
|
|
+ FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=65000),
|
|
|
+ FieldSchema(name="object1", dtype=DataType.VARCHAR, max_length=65000),
|
|
|
+ FieldSchema(name="object2", dtype=DataType.VARCHAR, max_length=65000),
|
|
|
+ FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65000),
|
|
|
+ FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=VECTOR_DIMENSION)
|
|
|
+ ]
|
|
|
+ schema = CollectionSchema(fields, "Semantic text chunks collection")
|
|
|
+ return Collection(name=collection_name, schema=schema)
|
|
|
+ return Collection(name=collection_name)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _split_sentences(text):
|
|
|
+ sentences = re.split('([。!?])', text)
|
|
|
+ sentences = [''.join(i) for i in zip(sentences[0::2], sentences[1::2] + [''])]
|
|
|
+ return [s.strip() for s in sentences if s.strip()]
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _insert_batch(collection, batch_data, source, object1, object2):
|
|
|
+ try:
|
|
|
+ entities = [
|
|
|
+ [source] * len(batch_data), # source
|
|
|
+ [object1]* len(batch_data), # object1
|
|
|
+ [object2]* len(batch_data), # object2
|
|
|
+ [chunk for chunk, _ in batch_data],
|
|
|
+ [embedding.tolist() for _, embedding in batch_data]
|
|
|
+ ]
|
|
|
+ # 插入数据并获取插入操作的结果
|
|
|
+ insert_result = collection.insert(entities)
|
|
|
+
|
|
|
+ # 获取插入的 ID
|
|
|
+ inserted_ids = insert_result.primary_keys
|
|
|
+ return inserted_ids
|
|
|
+
|
|
|
+ logging.info(f"成功插入{len(batch_data)}个文本块到Milvus")
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"批量插入数据时发生错误: {str(e)}")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _create_index_and_load(collection):
|
|
|
+ index_params = {
|
|
|
+ "index_type": "IVF_FLAT",
|
|
|
+ "metric_type": "L2",
|
|
|
+ "params": {"nlist": 768}
|
|
|
+ }
|
|
|
+ collection.create_index("embedding", index_params)
|
|
|
+ collection.load()
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def async_analysis(document_id, start_page, end_page, max_tokens):
|
|
|
+ start_time = time.time()
|
|
|
+ excel_status = 1
|
|
|
+ try:
|
|
|
+ logger.info(f"开始处理文档 {document_id}")
|
|
|
+
|
|
|
+ # 更新文档状态为处理中
|
|
|
+ DocumentKbm.objects.filter(id=document_id).update(run=1)
|
|
|
+
|
|
|
+ document = get_object_or_404(DocumentKbm, id=document_id)
|
|
|
+ object_name = document.location
|
|
|
+ file_extension = object_name.split('.')[-1].lower()
|
|
|
+
|
|
|
+ knowledgebase = get_object_or_404(Knowledgebase, id=document.kb_id)
|
|
|
+ bucket_name = knowledgebase.location
|
|
|
+
|
|
|
+ KbmService.clearPreviousData(document_id, bucket_name)
|
|
|
+
|
|
|
+ response = minio_client.get_object(bucket_name, object_name)
|
|
|
+ file_content = BytesIO(response.read())
|
|
|
+
|
|
|
+ if file_extension in ['xls', 'xlsx']:
|
|
|
+ result, excel_status = KbmService.process_excel(file_content, document_id, max_tokens, bucket_name)
|
|
|
+ elif file_extension == 'pdf':
|
|
|
+ result = KbmService.process_pdf(file_content, document_id, max_tokens,bucket_name)
|
|
|
+ # elif file_extension == 'txt':
|
|
|
+ # result = KbmService.process_txt(file_content, document_id, max_tokens, bucket_name)
|
|
|
+ elif file_extension == 'md':
|
|
|
+ result = KbmService.process_markdown(file_content, document_id, max_tokens, bucket_name)
|
|
|
+ elif file_extension in ['doc', 'docx']:
|
|
|
+ pdf_content = KbmService.convert_doc_to_pdf(file_content)
|
|
|
+ result = KbmService.process_pdf(pdf_content, document_id, max_tokens, bucket_name)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Unsupported file type: {file_extension}")
|
|
|
+
|
|
|
+ KbmService.saveTask(document_id, len(result))
|
|
|
+ end_time = time.time()
|
|
|
+ execution_time = round(end_time - start_time, 2)
|
|
|
+ KbmService.updateDocument(max_tokens, len(result), document_id, execution_time)
|
|
|
+
|
|
|
+ # 更新文档状态
|
|
|
+ if excel_status == 6:
|
|
|
+ DocumentKbm.objects.filter(id=document_id).update(run=6) # Excel 特殊情况
|
|
|
+ logger.info(f"文档 {document_id} 更新完成,状态设置为6(Excel特殊情况)")
|
|
|
+ else:
|
|
|
+ DocumentKbm.objects.filter(id=document_id).update(run=3) # 假设3表示成功状态
|
|
|
+ logger.info(f"文档 {document_id} 更新完成,状态设置为3(成功)")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Analysis failed for document {document_id}: {str(e)}")
|
|
|
+ logger.error("Exception traceback:")
|
|
|
+ traceback.print_exc()
|
|
|
+ # 更新文档状态为失败
|
|
|
+ DocumentKbm.objects.filter(id=document_id).update(run=4)
|
|
|
+
|
|
|
+ logger.info(f"文档 {document_id} 处理完成")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def convert_doc_to_pdf(file_content):
|
|
|
+ try:
|
|
|
+ # 创建临时文件
|
|
|
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.docx') as temp_input:
|
|
|
+ temp_input.write(file_content.getvalue())
|
|
|
+ temp_input_path = temp_input.name
|
|
|
+
|
|
|
+ temp_output_dir = tempfile.mkdtemp()
|
|
|
+
|
|
|
+ # 查找 LibreOffice 路径
|
|
|
+ libreoffice_path = KbmService.get_libreoffice_path()
|
|
|
+ if not libreoffice_path:
|
|
|
+ raise FileNotFoundError("找不到 LibreOffice 可执行文件")
|
|
|
+
|
|
|
+ # 转换为 PDF
|
|
|
+ pdf_path = KbmService.run_libreoffice_conversion(libreoffice_path, temp_input_path, temp_output_dir)
|
|
|
+
|
|
|
+ # 读取 PDF 内容
|
|
|
+ with open(pdf_path, 'rb') as pdf_file:
|
|
|
+ pdf_content = pdf_file.read()
|
|
|
+
|
|
|
+
|
|
|
+ # 读取并返回 PDF 内容
|
|
|
+ with open(pdf_path, 'rb') as pdf_file:
|
|
|
+ return BytesIO(pdf_file.read())
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"将文档转换为 PDF 时出错: {str(e)}", exc_info=True)
|
|
|
+ return BytesIO()
|
|
|
+
|
|
|
+ finally:
|
|
|
+ KbmService.cleanup_temp_files(temp_input_path, temp_output_dir)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def get_libreoffice_path():
|
|
|
+ system = platform.system()
|
|
|
+ if system == "Windows":
|
|
|
+ libreoffice_paths = [r"D:\Program Files\libreoffice\program\soffice.exe"]
|
|
|
+ return next((path for path in libreoffice_paths if os.path.exists(path)), None)
|
|
|
+ else: # Linux 或 macOS
|
|
|
+ for path in ['/usr/bin/libreoffice', '/usr/bin/soffice', '/opt/libreoffice/program/soffice']:
|
|
|
+ if os.path.exists(path):
|
|
|
+ return path
|
|
|
+ return shutil.which('libreoffice') or shutil.which('soffice')
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def run_libreoffice_conversion(libreoffice_path, input_path, output_dir):
|
|
|
+ cmd = [
|
|
|
+ libreoffice_path,
|
|
|
+ '--headless',
|
|
|
+ '--convert-to', 'pdf:writer_pdf_Export:{"PageSize":{"Width":21000,"Height":29700}}',
|
|
|
+ '--outdir', output_dir,
|
|
|
+ input_path
|
|
|
+ ]
|
|
|
+ try:
|
|
|
+ env = os.environ.copy()
|
|
|
+ env['HOME'] = '/mnt/ql_api/tmp' # 设置一个临时的 HOME 目录
|
|
|
+ env['LC_ALL'] = 'C' # 设置一个标准的语言环境
|
|
|
+ sleep(0.3)
|
|
|
+ result = subprocess.run(cmd, check=True, capture_output=True, text=True, timeout=60, env=env)
|
|
|
+ logging.info(f"LibreOffice 转换输出: {result.stdout}")
|
|
|
+
|
|
|
+ pdf_filename = os.path.splitext(os.path.basename(input_path))[0] + '.pdf'
|
|
|
+ pdf_path = os.path.join(output_dir, pdf_filename)
|
|
|
+
|
|
|
+ if not os.path.exists(pdf_path):
|
|
|
+ raise FileNotFoundError(f"PDF 文件未生成。输出目录内容: {os.listdir(output_dir)}")
|
|
|
+
|
|
|
+ # 使用 PyPDF2 检查页数
|
|
|
+ with open(pdf_path, 'rb') as pdf_file:
|
|
|
+ pdf_reader = PyPDF2.PdfReader(pdf_file)
|
|
|
+ page_count = len(pdf_reader.pages)
|
|
|
+ logging.info(f"生成的 PDF 文件页数: {page_count}")
|
|
|
+
|
|
|
+ pdf_size = os.path.getsize(pdf_path)
|
|
|
+ if pdf_size < 1000:
|
|
|
+ logging.warning(f"生成的 PDF 文件大小异常小: {pdf_size} bytes")
|
|
|
+
|
|
|
+ return pdf_path
|
|
|
+ except subprocess.TimeoutExpired:
|
|
|
+ raise TimeoutError("LibreOffice 转换超时")
|
|
|
+ except subprocess.CalledProcessError as e:
|
|
|
+ raise RuntimeError(f"LibreOffice 转换失败: {e.output}")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def cleanup_temp_files(temp_input_path, temp_output_dir):
|
|
|
+ if os.path.exists(temp_input_path):
|
|
|
+ os.remove(temp_input_path)
|
|
|
+ if os.path.exists(temp_output_dir):
|
|
|
+ shutil.rmtree(temp_output_dir)
|
|
|
+ @staticmethod
|
|
|
+ def process_excel(file_content, document_id, max_tokens, bucket_name):
|
|
|
+ result = []
|
|
|
+ collection_name = f"{bucket_name}"
|
|
|
+ status=1
|
|
|
+ # object1 object1为后续可能添加的字段 因为无法直接修改名称 备用
|
|
|
+ source = "知识库"
|
|
|
+ object1 = "some_object1"
|
|
|
+ object2 = "some_object2"
|
|
|
+
|
|
|
+ def warning_catcher(message, category, filename, lineno, file=None, line=None):
|
|
|
+ nonlocal status
|
|
|
+ if category == UserWarning:
|
|
|
+ if "File contains an invalid specification for 0" in str(message) or \
|
|
|
+ "Defined names for sheet index 0 cannot be located" in str(message):
|
|
|
+ status = 6
|
|
|
+ logger.error(f"Warning: {message}")
|
|
|
+
|
|
|
+ warnings.showwarning = warning_catcher
|
|
|
+ try:
|
|
|
+ excel_file = pd.ExcelFile(file_content)
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"Error reading Excel file: {str(e)}")
|
|
|
+ return result
|
|
|
+
|
|
|
+
|
|
|
+ try:
|
|
|
+ connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT,user=MILVUS_USER,password=MILVUS_PASSWORD)
|
|
|
+ collection = KbmService._get_or_create_collection(collection_name)
|
|
|
+
|
|
|
+ for sheet_name in excel_file.sheet_names:
|
|
|
+ df = pd.read_excel(excel_file, sheet_name=sheet_name)
|
|
|
+ if df.empty:
|
|
|
+ logging.warning(f"Sheet '{sheet_name}' is empty. Skipping.")
|
|
|
+ continue
|
|
|
+
|
|
|
+ logging.info(f"Processing sheet '{sheet_name}' with shape {df.shape}")
|
|
|
+
|
|
|
+ markdown_content = KbmService._excel_to_markdown(df, sheet_name)
|
|
|
+ chunks = KbmService._split_markdown(markdown_content, max_tokens)
|
|
|
+
|
|
|
+ for chunk_number, chunk_content in enumerate(chunks, start=1):
|
|
|
+ try:
|
|
|
+ if not chunk_content.strip():
|
|
|
+ logging.warning(f"Empty chunk {chunk_number} in sheet '{sheet_name}'. Skipping.")
|
|
|
+ continue
|
|
|
+
|
|
|
+ embedding = KbmService.get_embedding_excel(chunk_content, target_dim=VECTOR_DIMENSION)
|
|
|
+
|
|
|
+ if isinstance(embedding, (list, np.ndarray)) and len(embedding) == VECTOR_DIMENSION:
|
|
|
+ sleep(1)
|
|
|
+ logger.info("减少milvus压力睡眠1秒")
|
|
|
+ milvus_id = KbmService._insert_data(collection, chunk_content, embedding,source,object1,object2)
|
|
|
+ else:
|
|
|
+ logging.error(f"Invalid embedding format for chunk {chunk_number} of sheet {sheet_name}.")
|
|
|
+ continue
|
|
|
+
|
|
|
+ KbmService.saveTaskSublist(
|
|
|
+ document_id=document_id,
|
|
|
+ name=f"sheet_{sheet_name}",
|
|
|
+ page_number=1,
|
|
|
+ chunk_number=chunk_number,
|
|
|
+ content=chunk_content,
|
|
|
+ milvus_id=milvus_id
|
|
|
+ )
|
|
|
+
|
|
|
+ result.append({
|
|
|
+ 'page_number': 1,
|
|
|
+ 'chunk_number': chunk_number,
|
|
|
+ })
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"Error processing chunk {chunk_number} of sheet {sheet_name}: {str(e)}")
|
|
|
+ logging.error(f"Chunk content: {chunk_content}")
|
|
|
+ logging.exception("Detailed error information:")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error processing Excel file: {str(e)}")
|
|
|
+ logger.error("Detailed error information:")
|
|
|
+ import traceback
|
|
|
+ logger.error(traceback.format_exc())
|
|
|
+ status = 6 # 设置状态为6,表示处理出错
|
|
|
+ raise
|
|
|
+
|
|
|
+ finally:
|
|
|
+ connections.disconnect("default")
|
|
|
+
|
|
|
+ return result, status
|
|
|
+ #excel转markdown
|
|
|
+ @staticmethod
|
|
|
+ def _excel_to_markdown(df, sheet_name):
|
|
|
+ if df.empty:
|
|
|
+ return f"# {sheet_name}\n\n表格为空"
|
|
|
+
|
|
|
+ headers = df.columns.tolist()
|
|
|
+ data = df.values.tolist()
|
|
|
+
|
|
|
+ # 将所有数据转换为字符串
|
|
|
+ data = [[str(cell) for cell in row] for row in data]
|
|
|
+
|
|
|
+ markdown = f"# {sheet_name}\n\n"
|
|
|
+ markdown += tabulate(data, headers=headers, tablefmt="pipe", showindex=False)
|
|
|
+ return markdown
|
|
|
+ #excel分割策略
|
|
|
+ @staticmethod
|
|
|
+ def _split_json(json_str, max_tokens):
|
|
|
+ # 简单的分割策略,可以根据需要优化
|
|
|
+ data = json.loads(json_str)
|
|
|
+ chunks = []
|
|
|
+ current_chunk = []
|
|
|
+ current_size = 0
|
|
|
+
|
|
|
+ for item in data:
|
|
|
+ item_str = json.dumps(item)
|
|
|
+ item_size = len(item_str)
|
|
|
+ if current_size + item_size > max_tokens and current_chunk:
|
|
|
+ chunks.append(json.dumps(current_chunk))
|
|
|
+ current_chunk = []
|
|
|
+ current_size = 0
|
|
|
+ current_chunk.append(item)
|
|
|
+ current_size += item_size
|
|
|
+
|
|
|
+ if current_chunk:
|
|
|
+ chunks.append(json.dumps(current_chunk))
|
|
|
+
|
|
|
+ return chunks
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _split_markdown(markdown_content, max_tokens):
|
|
|
+ chunks = []
|
|
|
+ current_chunk = ""
|
|
|
+ lines = markdown_content.split('\n')
|
|
|
+
|
|
|
+ for line in lines:
|
|
|
+ if len(current_chunk) + len(line) + 1 > max_tokens:
|
|
|
+ if current_chunk:
|
|
|
+ chunks.append(current_chunk.strip())
|
|
|
+ current_chunk = line
|
|
|
+ else:
|
|
|
+ current_chunk += '\n' + line if current_chunk else line
|
|
|
+
|
|
|
+ if current_chunk:
|
|
|
+ chunks.append(current_chunk.strip())
|
|
|
+
|
|
|
+ return chunks
|
|
|
+ #milvus excel插入格式
|
|
|
+ @staticmethod
|
|
|
+ def _insert_data(collection, content, embedding,source,object1,object2):
|
|
|
+ try:
|
|
|
+ data = [
|
|
|
+ [source], # content field
|
|
|
+ [object1], # content field
|
|
|
+ [object2], # content field
|
|
|
+ [content], # content field
|
|
|
+ [embedding] # embedding field
|
|
|
+ ]
|
|
|
+ insert_result = collection.insert(data)
|
|
|
+ logging.info(f"Inserted 1 record into Milvus")
|
|
|
+ return insert_result.primary_keys[0] # 返回插入的 ID
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"Error inserting data into Milvus: {str(e)}")
|
|
|
+ raise
|
|
|
+ #创建milvus索引
|
|
|
+ @staticmethod
|
|
|
+ def _create_index_if_not_exists(collection):
|
|
|
+ if not collection.has_index():
|
|
|
+ index_params = {
|
|
|
+ "index_type": "IVF_FLAT",
|
|
|
+ "metric_type": "L2",
|
|
|
+ "params": {"nlist": 768}
|
|
|
+ }
|
|
|
+ collection.create_index("embedding", index_params)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def process_pdf(file_content, document_id, max_tokens, bucket_name):
|
|
|
+ logger.info(f"开始解析pdf")
|
|
|
+ doc = fitz.open(stream=file_content, filetype="pdf")
|
|
|
+ total_pages = len(doc)
|
|
|
+
|
|
|
+ # 提取整个文档的文本
|
|
|
+ full_text = ""
|
|
|
+ image_texts = []
|
|
|
+ page_images = []
|
|
|
+
|
|
|
+ for page_num in range(total_pages):
|
|
|
+ page = doc[page_num]
|
|
|
+ full_text += page.get_text()
|
|
|
+
|
|
|
+ # OCR识别图片
|
|
|
+ for img in page.get_images():
|
|
|
+ xref = img[0]
|
|
|
+ base_image = doc.extract_image(xref)
|
|
|
+ image_data = base_image["image"]
|
|
|
+ logger.info("开始ocr")
|
|
|
+ image_text = KbmService.extract_text_from_image(image_data)
|
|
|
+ sleep(1)
|
|
|
+ logger.info("识别完成,防止调用频繁睡眠1秒")
|
|
|
+ if image_text:
|
|
|
+ image_texts.append(image_text)
|
|
|
+
|
|
|
+ # 渲染页面为图像并保存
|
|
|
+ page_image = KbmService.render_page_to_image(page)
|
|
|
+ image_name = KbmService.save_image_to_minio(page_image, bucket_name)
|
|
|
+ page_images.append((page_num + 1, image_name))
|
|
|
+ # 将图片文本添加到全文中
|
|
|
+ if image_texts:
|
|
|
+ full_text += " ".join(image_texts)
|
|
|
+
|
|
|
+ # 对整个文本进行语义分割
|
|
|
+ text_chunks = KbmService.split_text_by_semantic(full_text, max_tokens, bucket_name)
|
|
|
+ logger.info("分割完成")
|
|
|
+
|
|
|
+ result = []
|
|
|
+ for i, chunk in enumerate(text_chunks):
|
|
|
+ # 为每个chunk分配一个页面图像
|
|
|
+ page_number, image_name = page_images[min(i, len(page_images) - 1)]
|
|
|
+
|
|
|
+ KbmService.saveTaskSublist(
|
|
|
+ document_id=document_id,
|
|
|
+ name=image_name,
|
|
|
+ page_number=page_number,
|
|
|
+ chunk_number=i + 1,
|
|
|
+ content=chunk['content'],
|
|
|
+ milvus_id=chunk['milvus_id']
|
|
|
+ )
|
|
|
+ result.append({
|
|
|
+ 'page_number': page_number,
|
|
|
+ 'chunk_number': i + 1,
|
|
|
+ })
|
|
|
+
|
|
|
+ doc.close()
|
|
|
+ logger.info("解析结束")
|
|
|
+ return result
|
|
|
+ # @staticmethod
|
|
|
+ # def process_pdf(file_content, document_id, start_page, end_page, max_tokens, bucket_name):
|
|
|
+ # print(f"开始解析pdf:::: {str(file_content)}")
|
|
|
+ # doc = fitz.open(stream=file_content, filetype="pdf")
|
|
|
+ # total_pages = len(doc)
|
|
|
+ #
|
|
|
+ # start_page = max(1, start_page) - 1
|
|
|
+ # end_page = min(total_pages, end_page if end_page > 0 else total_pages)
|
|
|
+ # result = []
|
|
|
+ # for page_num in range(start_page, end_page):
|
|
|
+ # page = doc[page_num]
|
|
|
+ # text = page.get_text()
|
|
|
+ #
|
|
|
+ # image_texts = []
|
|
|
+ # for img in page.get_images():
|
|
|
+ # xref = img[0]
|
|
|
+ # base_image = doc.extract_image(xref)
|
|
|
+ # image_data = base_image["image"]
|
|
|
+ # #OCR识别图片
|
|
|
+ # print("开始ocr")
|
|
|
+ # image_text = KbmService.extract_text_from_image(image_data)
|
|
|
+ # sleep(1)
|
|
|
+ # print("识别完成,防止调用频繁睡眠1秒")
|
|
|
+ # if image_text:
|
|
|
+ # image_texts.append(image_text)
|
|
|
+ # if image_texts:
|
|
|
+ # text += "".join(image_texts)
|
|
|
+ #
|
|
|
+ # # 使用语义分割替代简单的文本分割
|
|
|
+ # text_chunks = KbmService.split_text_by_semantic(text, max_tokens,bucket_name)
|
|
|
+ # print("分割完成")
|
|
|
+ # page_image = KbmService.render_page_to_image(page)
|
|
|
+ # print("转化图片")
|
|
|
+ # image_name = KbmService.save_image_to_minio(page_image, bucket_name)
|
|
|
+ # print("上传minio")
|
|
|
+ # for i, chunk in enumerate(text_chunks):
|
|
|
+ # KbmService.saveTaskSublist(
|
|
|
+ # document_id=document_id,
|
|
|
+ # name=image_name,
|
|
|
+ # page_number=page_num + 1,
|
|
|
+ # chunk_number=i + 1,
|
|
|
+ # content=chunk['content'],
|
|
|
+ # milvus_id=chunk['milvus_id']
|
|
|
+ # )
|
|
|
+ # result.append({
|
|
|
+ # 'page_number': page_num + 1,
|
|
|
+ # 'chunk_number': i + 1,
|
|
|
+ # })
|
|
|
+ #
|
|
|
+ # doc.close()
|
|
|
+ # print("解析结束")
|
|
|
+ # return result
|
|
|
+ #解析markdown
|
|
|
+ @staticmethod
|
|
|
+ def process_markdown(file_content, document_id, max_tokens, bucket_name):
|
|
|
+ logging.info(f"开始解析 Markdown,document_id: {document_id}, max_tokens: {max_tokens}")
|
|
|
+ try:
|
|
|
+ # 检测文件编码
|
|
|
+ raw_content = file_content.read()
|
|
|
+ detected = chardet.detect(raw_content)
|
|
|
+ encoding = detected['encoding']
|
|
|
+ logging.info(f"检测到的文件编码: {encoding}")
|
|
|
+
|
|
|
+ # 解码文件内容
|
|
|
+ text = raw_content.decode(encoding)
|
|
|
+ logging.info(f"Markdown 文件总字符数: {len(text)}")
|
|
|
+ logging.debug(f"Markdown 文件前100个字符: {text[:100]}")
|
|
|
+
|
|
|
+ # 将 Markdown 转换为 HTML
|
|
|
+ html = markdown.markdown(text)
|
|
|
+
|
|
|
+ # 使用 BeautifulSoup 提取纯文本
|
|
|
+ soup = BeautifulSoup(html, 'html.parser')
|
|
|
+ plain_text = soup.get_text()
|
|
|
+
|
|
|
+ logging.info(f"提取的纯文本总字符数: {len(plain_text)}")
|
|
|
+ logging.debug(f"提取的纯文本前100个字符: {plain_text[:100]}")
|
|
|
+
|
|
|
+ if not plain_text.strip():
|
|
|
+ logging.warning("Markdown 文件内容为空")
|
|
|
+ return []
|
|
|
+
|
|
|
+ # 使用 split_text_by_semantic 方法分割文本
|
|
|
+ text_chunks = KbmService.split_text_by_semantic(plain_text, max_tokens, bucket_name)
|
|
|
+ logging.info(f"分割后的文本块数: {len(text_chunks)}")
|
|
|
+
|
|
|
+ result = []
|
|
|
+ for i, chunk in enumerate(text_chunks, 1):
|
|
|
+ KbmService.saveTaskSublist(
|
|
|
+ document_id=document_id,
|
|
|
+ name="markdown_content",
|
|
|
+ page_number=1,
|
|
|
+ chunk_number=i,
|
|
|
+ content=chunk['content'],
|
|
|
+ milvus_id=chunk['milvus_id']
|
|
|
+ )
|
|
|
+ result.append({
|
|
|
+ 'page_number': 1,
|
|
|
+ 'chunk_number': i,
|
|
|
+ })
|
|
|
+
|
|
|
+ logging.info(f"Markdown 处理完成,总共生成 {len(result)} 个文本块")
|
|
|
+ return result
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"处理 Markdown 时发生错误: {str(e)}")
|
|
|
+ logging.exception("详细错误信息:")
|
|
|
+ return []
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def split_text(text, max_tokens):
|
|
|
+ words = text.split()
|
|
|
+ chunks = []
|
|
|
+ current_chunk = []
|
|
|
+ current_token_count = 0
|
|
|
+
|
|
|
+ for word in words:
|
|
|
+ word_tokens = KbmService.estimate_tokens(word)
|
|
|
+ if current_token_count + word_tokens > max_tokens and current_chunk:
|
|
|
+ chunks.append(' '.join(current_chunk))
|
|
|
+ current_chunk = []
|
|
|
+ current_token_count = 0
|
|
|
+ current_chunk.append(word)
|
|
|
+ current_token_count += word_tokens
|
|
|
+
|
|
|
+ if current_chunk:
|
|
|
+ chunks.append(' '.join(current_chunk))
|
|
|
+
|
|
|
+ return chunks
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def estimate_tokens(text):
|
|
|
+ return len(re.findall(r'\w+', text)) * 1.3
|
|
|
+ #OCR图片识别
|
|
|
+ ocr = PaddleOCR(use_angle_cls=True, lang="ch", use_gpu=False, det_db_thresh=0.3, det_db_box_thresh=0.3)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def extract_text_from_image(image_data):
|
|
|
+ try:
|
|
|
+ if isinstance(image_data, BytesIO):
|
|
|
+ image_data = image_data.getvalue()
|
|
|
+
|
|
|
+ # 读取图像
|
|
|
+ nparr = np.frombuffer(image_data, np.uint8)
|
|
|
+ image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
|
|
+
|
|
|
+ # 图像预处理
|
|
|
+ preprocessed = KbmService.preprocess_image(image)
|
|
|
+
|
|
|
+ # 使用PaddleOCR进行识别
|
|
|
+ result = KbmService.ocr.ocr(preprocessed, cls=True)
|
|
|
+
|
|
|
+ # 提取文本并去除所有空格
|
|
|
+ if result and isinstance(result[0], list):
|
|
|
+ text = ''.join(
|
|
|
+ [line[1][0].replace(' ', '') for line in result[0] if line[1][1] > 0.5]) # 只保留置信度大于0.5的结果,并去除空格
|
|
|
+ else:
|
|
|
+ text = ""
|
|
|
+
|
|
|
+ # 后处理
|
|
|
+ text = KbmService.post_process_text(text)
|
|
|
+
|
|
|
+ if text and len(text) > 2: # 假设有意义的文本至少有3个字符
|
|
|
+ logging.info(f"提取的文本长度: {len(text)}")
|
|
|
+ return text
|
|
|
+ else:
|
|
|
+ logging.info("提取的内容似乎是图像,而不是文本")
|
|
|
+ return "图片"
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"从图像提取文本时出错: {str(e)}")
|
|
|
+ return "图片"
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def preprocess_image(image):
|
|
|
+ # 转换为灰度图像
|
|
|
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
|
|
+
|
|
|
+ # 自适应阈值处理
|
|
|
+ binary = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2)
|
|
|
+
|
|
|
+ # 对二值图像进行膨胀操作,使文字更粗
|
|
|
+ kernel = np.ones((2, 2), np.uint8)
|
|
|
+ dilated = cv2.dilate(binary, kernel, iterations=1)
|
|
|
+
|
|
|
+ return dilated
|
|
|
+
|
|
|
+ def post_process_text(text):
|
|
|
+ # 将连续的冒号或点替换为空格
|
|
|
+ text = re.sub(r'[:.]+', ' ', text)
|
|
|
+
|
|
|
+ # 保留中文字符、英文字母、数字、常用标点
|
|
|
+ text = re.sub(r'[^\u4e00-\u9fff\u3000-\u303fa-zA-Z0-9.,!?;:()"\'\s]', '', text)
|
|
|
+
|
|
|
+ # 删除连续的数字(3个或更多)
|
|
|
+ text = re.sub(r'\d{3,}', '', text)
|
|
|
+
|
|
|
+ # 处理多余的空白字符
|
|
|
+ text = re.sub(r'\s+', ' ', text).strip()
|
|
|
+
|
|
|
+ # 删除单独的数字,但保留章节编号和有意义的数字
|
|
|
+ text = re.sub(r'\b(?<![第章])\d+(?!\d)\b', '', text)
|
|
|
+
|
|
|
+ # 清理多余的空格
|
|
|
+ text = re.sub(r'\s+', ' ', text).strip()
|
|
|
+
|
|
|
+ return text
|
|
|
+
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def render_page_to_image(page, scale=2):
|
|
|
+ pix = page.get_pixmap(matrix=fitz.Matrix(scale, scale))
|
|
|
+ img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
|
|
+ buffered = BytesIO()
|
|
|
+ img.save(buffered, format="PNG")
|
|
|
+ buffered.seek(0)
|
|
|
+ return buffered
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def save_image_to_minio(image_data, bucket_name):
|
|
|
+ image_name = f"page_image_{uuid.uuid4()}.png"
|
|
|
+ minio_client.put_object(bucket_name, image_name, image_data, length=image_data.getbuffer().nbytes)
|
|
|
+ return image_name
|
|
|
+ @staticmethod
|
|
|
+ @transaction.atomic
|
|
|
+ def saveTask(document_id, total_chunks):
|
|
|
+ Task.objects.update_or_create(
|
|
|
+ doc_id=document_id, # 查找条件
|
|
|
+ defaults={'to_page': total_chunks} # 要更新或创建的字段
|
|
|
+ )
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ @transaction.atomic
|
|
|
+ def saveTaskSublist(document_id, name, page_number=None, chunk_number=None, content=None,milvus_id=None):
|
|
|
+ try:
|
|
|
+ # 确保 content 是 Unicode 字符串
|
|
|
+ if content is not None:
|
|
|
+ if isinstance(content, bytes):
|
|
|
+ content = content.decode('utf-8')
|
|
|
+ else:
|
|
|
+ content = str(content)
|
|
|
+
|
|
|
+ TaskSublist.objects.create(
|
|
|
+ doc_id=document_id,
|
|
|
+ name=name,
|
|
|
+ page_number=str(page_number) if page_number is not None else '0',
|
|
|
+ chunk_number=str(chunk_number) if chunk_number is not None else '0',
|
|
|
+ content=content,
|
|
|
+ milvus_id=milvus_id
|
|
|
+ )
|
|
|
+ logging.info(f"Successfully saved TaskSublist for document {document_id}, chunk {chunk_number}")
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"Error saving TaskSublist: {str(e)}")
|
|
|
+ # 可以选择在这里重新抛出异常,或者进行其他错误处理
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ @transaction.atomic
|
|
|
+ def clearPreviousData(document_id,bucket_name):
|
|
|
+ try:
|
|
|
+ # 获取与文档相关的所有 TaskSublist,milvus 记录
|
|
|
+ task_sublists = TaskSublist.objects.filter(doc_id=document_id)
|
|
|
+ # 获取集合对象
|
|
|
+ connections.connect("default", host=MILVUS_HOST, port=MILVUS_PORT)
|
|
|
+
|
|
|
+ milvus_collection_exists = utility.has_collection(bucket_name)
|
|
|
+ if milvus_collection_exists:
|
|
|
+ collection = Collection(bucket_name)
|
|
|
+ # 连接到 Milvus
|
|
|
+
|
|
|
+ # 从 MinIO 中删除相关的图片
|
|
|
+ for task in task_sublists:
|
|
|
+ try:
|
|
|
+ # 执行删除minio
|
|
|
+ minio_client.remove_object(bucket_name, task.name)
|
|
|
+
|
|
|
+ # 执行删除milvus
|
|
|
+ if milvus_collection_exists and task.milvus_id:
|
|
|
+ expr = f'id in [{task.milvus_id}]'
|
|
|
+ collection.delete(expr)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error deleting object from MinIO: {e}")
|
|
|
+
|
|
|
+ # 从数据库中删除 TaskSublist 记录
|
|
|
+ task_sublists.delete()
|
|
|
+ except Exception as e:
|
|
|
+ logging.error(f"Error deleting object from MinIO: {str(e)}")
|
|
|
+ finally:
|
|
|
+ connections.disconnect("default")
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ @transaction.atomic
|
|
|
+ def updateDocument(max_tokens, total_chunks, document_id,execution_time):
|
|
|
+ try:
|
|
|
+ # 检查是否存在相关的 TaskSublist
|
|
|
+ count = TaskSublist.objects.filter(doc_id=document_id).count()
|
|
|
+
|
|
|
+ # 根据 TaskSublist 的存在与否设置运行状态
|
|
|
+ progress_status = 1 if count > 0 else -1
|
|
|
+
|
|
|
+
|
|
|
+ # 更新 DocumentKbm 对象
|
|
|
+ updated = DocumentKbm.objects.filter(id=document_id).update(
|
|
|
+ token_num=max_tokens,
|
|
|
+ chunk_num=total_chunks,
|
|
|
+ progress=progress_status,
|
|
|
+ process_begin_at=timezone.now(),
|
|
|
+ process_duation= execution_time
|
|
|
+ )
|
|
|
+
|
|
|
+ if updated:
|
|
|
+ return True, f"Document {document_id} updated successfully."
|
|
|
+ else:
|
|
|
+ return False, f"Document {document_id} not found."
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ # 如果发生任何错误,事务会自动回滚
|
|
|
+ return False, f"Error updating document: {str(e)}"
|
|
|
+ #异步调用
|
|
|
+
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def searchTaskInfo(request):
|
|
|
+ document_id = request.POST.get("document_id")
|
|
|
+ page = request.POST.get('page', 1)
|
|
|
+ page_size = request.POST.get('page_size', 10) # 每页显示的项目数,默认为10
|
|
|
+
|
|
|
+ taskSublists = TaskSublist.objects.filter(doc_id=document_id).order_by('id')
|
|
|
+
|
|
|
+ document = get_object_or_404(DocumentKbm, id=document_id)
|
|
|
+ location = document.location
|
|
|
+
|
|
|
+ knowledgebase = get_object_or_404(Knowledgebase, id=document.kb_id)
|
|
|
+ bucket_name = knowledgebase.location
|
|
|
+
|
|
|
+ documentUrl = minio_client.presigned_get_object(
|
|
|
+ bucket_name=bucket_name,
|
|
|
+ object_name=location,
|
|
|
+ expires=timedelta(days=1) # URL有效期为1天
|
|
|
+ )
|
|
|
+
|
|
|
+ # 创建分页器
|
|
|
+ paginator = Paginator(taskSublists, page_size)
|
|
|
+
|
|
|
+ try:
|
|
|
+ tasks_page = paginator.page(page)
|
|
|
+ except PageNotAnInteger:
|
|
|
+ # 如果页码不是整数,返回第一页
|
|
|
+ tasks_page = paginator.page(1)
|
|
|
+ except EmptyPage:
|
|
|
+ # 如果页码超出范围,返回最后一页
|
|
|
+ tasks_page = paginator.page(paginator.num_pages)
|
|
|
+
|
|
|
+ task_results = []
|
|
|
+ for task in tasks_page:
|
|
|
+ try:
|
|
|
+ # 生成MinIO对象的预签名URL
|
|
|
+ url = minio_client.presigned_get_object(
|
|
|
+ bucket_name=bucket_name,
|
|
|
+ object_name=task.name,
|
|
|
+ expires=timedelta(days=1) # URL有效期为1天
|
|
|
+ )
|
|
|
+
|
|
|
+ task_results.append({
|
|
|
+ 'id': task.id,
|
|
|
+ 'doc_id': task.doc_id,
|
|
|
+ 'name': task.name,
|
|
|
+ 'page_number': task.page_number,
|
|
|
+ 'chunk_number': task.chunk_number,
|
|
|
+ 'content': task.content,
|
|
|
+ 'url': url
|
|
|
+ })
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error generating URL for object {task.name}: {str(e)}")
|
|
|
+ # 如果生成URL失败,我们仍然添加其他信息,但URL为None
|
|
|
+ task_results.append({
|
|
|
+ 'id': task.id,
|
|
|
+ 'doc_id': task.doc_id,
|
|
|
+ 'name': task.name,
|
|
|
+ 'page_number': task.page_number,
|
|
|
+ 'chunk_number': task.chunk_number,
|
|
|
+ 'content': task.content,
|
|
|
+ 'url': None
|
|
|
+ })
|
|
|
+
|
|
|
+ # 创建包含 documentUrl 和分页信息的最终结果
|
|
|
+ result = {
|
|
|
+ 'documentUrl': documentUrl,
|
|
|
+ 'tasks': task_results,
|
|
|
+ 'pagination': {
|
|
|
+ 'current_page': tasks_page.number,
|
|
|
+ 'num_pages': paginator.num_pages,
|
|
|
+ 'per_page': page_size,
|
|
|
+ 'total_count': paginator.count,
|
|
|
+ 'has_next': tasks_page.has_next(),
|
|
|
+ 'has_previous': tasks_page.has_previous(),
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return success(result)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ @transaction.atomic
|
|
|
+ def deleteBucket(request):
|
|
|
+ bucket_id = request.POST.get("bucket_id")
|
|
|
+
|
|
|
+ if not bucket_id:
|
|
|
+ return fail("Bucket ID 为空")
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 检查是否存在未删除的文档
|
|
|
+ active_docs_count = DocumentKbm.objects.filter(kb_id=bucket_id).exclude(status=4).count()
|
|
|
+
|
|
|
+ if active_docs_count > 0:
|
|
|
+ return fail(f"无法删除知识库,还有 {active_docs_count} 个未删除的文档")
|
|
|
+
|
|
|
+ # 如果没有未删除的文档,则更新知识库状态
|
|
|
+ updated_count = Knowledgebase.objects.filter(id=bucket_id).update(status=4,name=bucket_id, location=bucket_id)
|
|
|
+
|
|
|
+ if updated_count == 0:
|
|
|
+ return fail("指定的知识库不存在")
|
|
|
+
|
|
|
+ return success("知识库已成功删除")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ return fail(f"删除知识库时发生错误: {str(e)}")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def getRunStatus(request):
|
|
|
+ document_id = request.POST.get("document_id")
|
|
|
+ run = DocumentKbm.objects.filter(id = document_id).values("run").first()
|
|
|
+ return success(run)
|
|
|
+ @staticmethod
|
|
|
+ def batchAnalysis(request):
|
|
|
+ ids_str = request.POST.get("ids")
|
|
|
+ start_page = int(request.POST.get('start_page', 1))
|
|
|
+ end_page = int(request.POST.get('end_page', -1))
|
|
|
+ max_tokens = int(request.POST.get('max_tokens', 2048))
|
|
|
+ try:
|
|
|
+ # 尝试将字符串解析为 JSON 列表
|
|
|
+ ids = json.loads(ids_str)
|
|
|
+
|
|
|
+ if not isinstance(ids, list):
|
|
|
+ return fail("无效输入:'ids'应该是一个列表")
|
|
|
+ results = []
|
|
|
+ for document_id in ids:
|
|
|
+ sleep(0.1)
|
|
|
+ logger.info("缓解压力沉睡0.1秒")
|
|
|
+ # 为每个 document_id 创建一个新的请求对象
|
|
|
+ analysis_request = type('AnalysisRequest', (), {})()
|
|
|
+ analysis_request.POST = {
|
|
|
+ 'document_id': document_id,
|
|
|
+ 'start_page': start_page,
|
|
|
+ 'end_page': end_page,
|
|
|
+ 'max_tokens': max_tokens
|
|
|
+ }
|
|
|
+
|
|
|
+ # 调用 analysis 方法
|
|
|
+ response = KbmService.analysis(analysis_request)
|
|
|
+ message = response.get('message')
|
|
|
+
|
|
|
+ return success(message,"已添加到队列")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+
|
|
|
+ return fail(f"An error occurred: {str(e)}")
|
|
|
+
|
|
|
+ # 假设这是您支持的文件后缀名列表
|
|
|
+ SUPPORTED_SUFFIXES = [
|
|
|
+ 'txt', 'pdf', 'doc', 'docx', 'xls', 'xlsx', 'md'
|
|
|
+ ]
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def getSuffixName(request):
|
|
|
+ try:
|
|
|
+ # 获取数据库中的所有不重复的 type 值
|
|
|
+ db_types = DocumentKbm.objects.values_list('type', flat=True).distinct()
|
|
|
+
|
|
|
+ # 将数据库中的类型转换为集合
|
|
|
+ db_types_set = set(db_types)
|
|
|
+
|
|
|
+ # 将 SUPPORTED_SUFFIXES 转换为集合
|
|
|
+ supported_set = set(KbmService.SUPPORTED_SUFFIXES)
|
|
|
+
|
|
|
+ # 合并两个集合,自动去除重复项
|
|
|
+ combined_set = supported_set.union(db_types_set)
|
|
|
+
|
|
|
+ # 将结果转换回列表
|
|
|
+ combined_suffixes = list(combined_set)
|
|
|
+
|
|
|
+ # 对结果进行排序(可选)
|
|
|
+ combined_suffixes.sort()
|
|
|
+
|
|
|
+ return success(combined_suffixes)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ return fail(f"获取文件后缀名时发生错误: {str(e)}")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ @transaction.atomic
|
|
|
+ def batchMove(request):
|
|
|
+ ids = json.loads(request.POST.get("ids"))
|
|
|
+ doc_type_id = request.POST.get("doc_type_id")
|
|
|
+ if not doc_type_id:
|
|
|
+ return fail("分类id为空")
|
|
|
+ if not ids:
|
|
|
+ return fail("未传出文件id")
|
|
|
+
|
|
|
+ type = KbmDocumentType.objects.filter(id=doc_type_id).exclude(status=4).first()
|
|
|
+ if not type:
|
|
|
+ return fail("当前分类不存在")
|
|
|
+ DocumentKbm.objects.filter(id__in=ids).update(doc_type_id=doc_type_id)
|
|
|
+
|
|
|
+ return success("批量移动成功")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ @transaction.atomic
|
|
|
+ def moveDocument(doc_id, doc_type_id):
|
|
|
+ # 这个方法可以保持不变,因为它已经是单次更新操作
|
|
|
+ return DocumentKbm.objects.filter(id=doc_id).update(doc_type_id=doc_type_id)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ @transaction.atomic
|
|
|
+ def updateKbm(request):
|
|
|
+ id = request.POST.get("id")
|
|
|
+ if not id:
|
|
|
+ return fail("id为空")
|
|
|
+ kmb = Knowledgebase.objects.filter(id=id).first()
|
|
|
+ name = request.POST.get("name")
|
|
|
+ if not name:
|
|
|
+ return fail("名称不能为空")
|
|
|
+ if kmb:
|
|
|
+ kmb.name = name
|
|
|
+ kmb.description = request.POST.get("description","")
|
|
|
+ kmb.save()
|
|
|
+ return success("修改成功")
|
|
|
+ else:
|
|
|
+ return fail("修改失败")
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|