From 7f894745e92a96d8b889e15329b2a25ae3aa24e2 Mon Sep 17 00:00:00 2001 From: qhd <1766646056@qq.com> Date: Thu, 9 Apr 2026 09:11:43 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E5=A4=84=E7=90=86=E6=B5=81=E7=A8=8B=E5=92=8C=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/eino/a.go | 166 ------------- common/eino/b.go | 107 -------- common/eino/base_task.go | 49 ---- common/eino/c.go | 94 ------- common/eino/chat_model.go | 125 ++++++++++ common/eino/embedding_qwen.go | 273 -------------------- common/eino/priority_enum.go | 11 - common/eino/retriever.go | 120 ++++++--- common/eino/status_enum.go | 12 - common/eino/task_type.go | 14 -- common/gse/utils.go | 114 --------- common/task/base_task.go | 69 ++++++ common/task/consts.go | 30 +++ config.yml | 24 +- consts/public/table_name.go | 2 + controller/document.go | 4 +- controller/rag_query.go | 17 ++ dao/dataset_index.go | 1 + dao/document_chunk.go | 58 +++++ dao/keyword.go | 3 + dao/task.go | 58 +++++ go.mod | 9 +- go.sum | 17 +- main.go | 20 +- model/dto/document.go | 6 - model/dto/keyword.go | 13 +- model/dto/rag_query.go | 21 ++ model/dto/task.go | 65 +++++ model/entity/task.go | 66 +++++ service/document.go | 452 +++++++++++++++++++++++++--------- service/document_chunk.go | 49 ++-- service/rag_query.go | 52 ++++ service/task.go | 107 ++++++++ update.sql | 44 ++++ 34 files changed, 1216 insertions(+), 1056 deletions(-) delete mode 100644 common/eino/a.go delete mode 100644 common/eino/b.go delete mode 100644 common/eino/base_task.go delete mode 100644 common/eino/c.go create mode 100644 common/eino/chat_model.go delete mode 100644 common/eino/embedding_qwen.go delete mode 100644 common/eino/priority_enum.go delete mode 100644 common/eino/status_enum.go delete mode 100644 common/eino/task_type.go delete mode 100644 common/gse/utils.go create mode 100644 common/task/base_task.go create mode 100644 common/task/consts.go create mode 100644 controller/rag_query.go create mode 100644 dao/task.go create mode 100644 model/dto/rag_query.go create mode 100644 model/dto/task.go create mode 100644 model/entity/task.go create mode 100644 service/rag_query.go create mode 100644 service/task.go diff --git a/common/eino/a.go b/common/eino/a.go deleted file mode 100644 index 5b81744..0000000 --- a/common/eino/a.go +++ /dev/null @@ -1,166 +0,0 @@ -package eino - -import ( - "context" - "errors" - "fmt" - "io" - "log" - "os" - - "github.com/cloudwego/eino/components/prompt" - "github.com/cloudwego/eino/components/retriever" - "github.com/cloudwego/eino/schema" - "github.com/cloudwego/eino-ext/components/model/ark" -) - -func main() { - ctx := context.Background() - - // ========================================== - // 1. 初始化三大组件 - // ========================================== - // 1.1 向量检索(从知识库查客服知识) - ragRetriever := NewPGVectorRetriever() - - // 1.2 提示词模板(客服角色 + 历史 + 知识库 + 用户问题) - chatTpl := newCustomerServiceTemplate() - - // 1.3 大模型(ARK) - chatModel, err := ark.NewChatModel(ctx, &ark.ChatModelConfig{ - APIKey: os.Getenv("ARK_API_KEY"), - Model: os.Getenv("ARK_MODEL_ID"), - }) - if err != nil { - log.Fatal(err) - } - - // ========================================== - // 2. 模拟会话:从 DB 读取历史对话 - // ========================================== - sessionHistory := []*schema.Message{ - {Role: schema.User, Content: "你们发什么快递?"}, - {Role: schema.Assistant, Content: "默认发中通快递"}, - {Role: schema.User, Content: "可以发顺丰吗?"}, - } - - // 当前用户问题 - userQuery := "那顺丰需要加钱吗?" - - // ========================================== - // 3. RAG 检索知识库 - // ========================================== - docs, err := ragRetriever.Retrieve(ctx, userQuery) - if err != nil { - log.Fatal(err) - } - - // 拼接参考知识 - knowledge := "" - for i, doc := range docs { - knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content) - } - - // ========================================== - // 4. 模板格式化:系统提示 + 历史 + 知识 + 当前问题 - // ========================================== - msgs, err := chatTpl.Format(ctx, map[string]any{ - "history": sessionHistory, - "knowledge": knowledge, - "question": userQuery, - }) - if err != nil { - log.Fatal(err) - } - - // ========================================== - // 5. 流式调用大模型生成客服回答 - // ========================================== - fmt.Println("\n=== 客服回复 ===") - stream, err := chatModel.Stream(ctx, msgs) - if err != nil { - log.Fatal(err) - } - - fullReply := make([]*schema.Message, 0, 100) - for { - chunk, err := stream.Recv() - if errors.Is(err, io.EOF) { - break - } - if err != nil { - log.Fatal(err) - } - fmt.Print(chunk.Content) - fullReply = append(fullReply, chunk) - } - - // ========================================== - // 6. 拼接完整回复,存入 DB 作为新历史 - // ========================================== - replyMsg, _ := schema.ConcatMessages(fullReply) - sessionHistory = append(sessionHistory, - &schema.Message{Role: schema.User, Content: userQuery}, - replyMsg, - ) - - // 接下来把 sessionHistory 存回你的 MySQL/Redis 即可 -} - -// ========================================== -// 本地客服提示词模板(不需要 MCP) -// ========================================== -func newCustomerServiceTemplate() prompt.ChatTemplate { - // 系统提示 + 多轮对话 + 知识库 + 用户问题 - return prompt.FromMessages(schema.Messages{ - { - Role: schema.System, - Content: `你是电商智能客服,语气友好简洁。 -请严格根据参考知识回答,不知道就说“抱歉,这个问题我需要帮你转接人工”。 - -参考知识: -{{.knowledge}}`, - }, - // 历史对话会自动渲染在这里 - {{range .history}}{{.}},{{end}}, - // 当前用户问题 - {Role: schema.User, Content: "{{.question}}"}, - }) -} - -// ========================================== -// PGVector 检索器(简化可直接用) -// ========================================== -type PGVectorRetriever struct { - topK int -} - -func NewPGVectorRetriever() retriever.Retriever { - return &PGVectorRetriever{topK: 3} -} - -func (r *PGVectorRetriever) Retrieve( - ctx context.Context, - query string, - opts ...retriever.Option, -) ([]*schema.Document, error) { - - options := retriever.GetCommonOptions(nil, opts...) - topK := r.topK - if options.TopK != nil { - topK = *options.TopK - } - - // ===== 这里替换成你真实的 PG 向量检索 SQL ===== - // 模拟知识库 - return []*schema.Document{ - { - ID: "1", - Content: "顺丰快递需要补10元运费差价", - }, - { - ID: "2", - Content: "订单满99元可免费升级顺丰", - }, - }, nil -} diff --git a/common/eino/b.go b/common/eino/b.go deleted file mode 100644 index a1f17dc..0000000 --- a/common/eino/b.go +++ /dev/null @@ -1,107 +0,0 @@ -package eino - -import ( - "context" - "fmt" - - "github.com/cloudwego/eino/schema" - "github.com/elastic/go-elasticsearch/v8" - - "github.com/cloudwego/eino-ext/components/indexer/es8" -) - -const ( - indexName = "eino_example" - fieldContent = "content" - fieldContentVector = "content_vector" - fieldExtraLocation = "location" - docExtraLocation = "location" -) - -func TestIndexer() { - ctx := context.Background() - - // 1. 创建 ES 客户端 - client, err := elasticsearch.NewClient(elasticsearch.Config{ - Addresses: []string{"http://localhost:9200"}, - }) - - if err != nil { - fmt.Printf("create client error: %v\n", err) - return - } - - // 2. 定义 Index Spec(选填:如果索引不存在,将自动创建) - indexSpec := &es8.IndexSpec{ - Settings: map[string]any{ - "number_of_shards": 1, - "number_of_replicas": 0, - }, - Mappings: map[string]any{ - "properties": map[string]any{ - fieldContentVector: map[string]any{ - "type": "dense_vector", - "dims": 1024, - "index": true, - "similarity": "l2_norm", - }, - }, - }, - } - - // 4. 准备文档 - // 文档通常包含 ID 和 Content - // 也可以包含额外的 Metadata 用于过滤或其他用途 - docs := []*schema.Document{ - { - ID: "1", - Content: "Eiffel Tower: Located in Paris, France.", - MetaData: map[string]any{ - docExtraLocation: "France", - }, - }, - { - ID: "2", - Content: "The Great Wall: Located in China.", - MetaData: map[string]any{ - docExtraLocation: "China", - }, - }, - } - - // 5. 创建 ES 索引器组件 - indexer, err := es8.NewIndexer(ctx, &es8.IndexerConfig{ - Client: client, - Index: indexName, - IndexSpec: indexSpec, // 添加此项以启用自动索引创建 - BatchSize: 10, - // DocumentToFields 指定如何将文档字段映射到 ES 字段 - DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]es8.FieldValue, err error) { - return map[string]es8.FieldValue{ - fieldContent: { - Value: doc.Content, - EmbedKey: fieldContentVector, // 对文档内容进行向量化并保存到 "content_vector" 字段 - }, - fieldExtraLocation: { - // 额外的 metadata 字段 - Value: doc.MetaData[docExtraLocation], - }, - }, nil - }, - // 提供 embedding 组件用于向量化 - Embedding: EmbedderDashscope, - }) - - if err != nil { - fmt.Printf("create indexer error: %v\n", err) - return - } - - // 6. 索引文档 - ids, err := indexer.Store(ctx, docs) - if err != nil { - fmt.Printf("index error: %v\n", err) - return - } - fmt.Println("indexed ids:", ids) -} diff --git a/common/eino/base_task.go b/common/eino/base_task.go deleted file mode 100644 index b883341..0000000 --- a/common/eino/base_task.go +++ /dev/null @@ -1,49 +0,0 @@ -package eino - -import ( - "time" - - "gitea.com/red-future/common/beans" -) - -// BaseTask 任务基类 - MongoDB版本 -type BaseTask struct { - beans.MongoBaseDO `bson:",inline"` - // 任务信息 - TaskType TaskType `bson:"taskType" json:"taskType"` - Status TaskStatus `bson:"status" json:"status"` - Priority TaskPriority `bson:"priority,omitempty" json:"priority,omitempty"` - // 进度 - TotalItems int64 `bson:"totalItems" json:"totalItems"` - ProcessedItems int64 `bson:"processedItems" json:"processedItems"` - Progress float64 `bson:"progress" json:"progress"` - // 结果 - StartTime *time.Time `bson:"startTime" json:"startTime"` - EndTime *time.Time `bson:"endTime,omitempty" json:"endTime,omitempty"` - Duration int64 `bson:"duration,omitempty" json:"duration,omitempty"` - SuccessCount int64 `bson:"successCount" json:"successCount"` - FailCount int64 `bson:"failCount" json:"failCount"` - // 其他 - Executor string `bson:"executor,omitempty" json:"executor,omitempty"` -} - -// SQLBaseTask 任务基类 - SQL版本 -type SQLBaseTask struct { - beans.SQLBaseDO - // 任务信息 - TaskType TaskType `json:"taskType"` - Status TaskStatus `json:"status"` - Priority TaskPriority `json:"priority,omitempty"` - // 进度 - TotalItems int64 `json:"totalItems"` - ProcessedItems int64 `json:"processedItems"` - Progress float64 `json:"progress"` - // 结果 - StartTime *time.Time `json:"startTime"` - EndTime *time.Time `json:"endTime,omitempty"` - Duration int64 `json:"duration,omitempty"` - SuccessCount int64 `json:"successCount"` - FailCount int64 `json:"failCount"` - // 其他 - Executor string `json:"executor,omitempty"` -} diff --git a/common/eino/c.go b/common/eino/c.go deleted file mode 100644 index ffc634a..0000000 --- a/common/eino/c.go +++ /dev/null @@ -1,94 +0,0 @@ -package eino - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/cloudwego/eino/schema" - "github.com/elastic/go-elasticsearch/v8" - "github.com/elastic/go-elasticsearch/v8/typedapi/types" - - "github.com/cloudwego/eino-ext/components/retriever/es8" - "github.com/cloudwego/eino-ext/components/retriever/es8/search_mode" -) - -func TestRetriever() { - ctx := context.Background() - - client, _ := elasticsearch.NewClient(elasticsearch.Config{ - Addresses: []string{"http://localhost:9200"}, - }) - - // 创建 retriever 组件 - retriever, _ := es8.NewRetriever(ctx, &es8.RetrieverConfig{ - Client: client, - Index: indexName, - TopK: 5, - SearchMode: search_mode.SearchModeApproximate(&search_mode.ApproximateConfig{ - QueryFieldName: fieldContent, - VectorFieldName: fieldContentVector, - Hybrid: false, - // RRF 仅在特定许可证下可用 - // 参见: https://www.elastic.co/subscriptions - RRF: false, - RRFRankConstant: nil, - RRFWindowSize: nil, - }), - ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) { - doc = &schema.Document{ - ID: *hit.Id_, - Content: "", - MetaData: map[string]any{}, - } - - var src map[string]any - if err = json.Unmarshal(hit.Source_, &src); err != nil { - return nil, err - } - - for field, val := range src { - switch field { - case fieldContent: - doc.Content = val.(string) - case fieldContentVector: - var v []float64 - for _, item := range val.([]interface{}) { - v = append(v, item.(float64)) - } - doc.WithDenseVector(v) - case fieldExtraLocation: - doc.MetaData[docExtraLocation] = val.(string) - } - } - - if hit.Score_ != nil { - doc.WithScore(float64(*hit.Score_)) - } - - return doc, nil - }, - Embedding: EmbedderDashscope, - }) - - // 不带过滤器的搜索 - docs, _ := retriever.Retrieve(ctx, "tourist attraction") - - // 带过滤器的搜索 - docs, _ = retriever.Retrieve(ctx, "tourist attraction", - es8.WithFilters([]types.Query{{ - Term: map[string]types.TermQuery{ - fieldExtraLocation: { - CaseInsensitive: of(true), - Value: "China", - }, - }, - }}), - ) - - fmt.Printf("retrieved docs: %+v\n", docs) -} - -func of[T any](v T) *T { - return &v -} diff --git a/common/eino/chat_model.go b/common/eino/chat_model.go new file mode 100644 index 0000000..4d9e4b7 --- /dev/null +++ b/common/eino/chat_model.go @@ -0,0 +1,125 @@ +package eino + +import ( + "context" + "errors" + "fmt" + "io" + + "github.com/cloudwego/eino-ext/components/model/qwen" + "github.com/cloudwego/eino/components/prompt" + "github.com/cloudwego/eino/schema" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/os/glog" + "github.com/gogf/gf/v2/util/gconv" +) + +var globalChatModel *qwen.ChatModel + +func init() { + ctx := context.Background() + + apiKey := g.Cfg().MustGet(ctx, "eino.chatmodel.apiKey").String() + model := g.Cfg().MustGet(ctx, "eino.chatmodel.model").String() + + var err error + globalChatModel, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{ + APIKey: apiKey, + Model: model, + BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1", + Temperature: gconv.PtrFloat32(0.7), // 客服最佳 + MaxTokens: gconv.PtrInt(1024), // 最长回答 + TopP: gconv.PtrFloat32(1.0), + }) + if err != nil { + glog.Errorf(ctx, "初始化大模型失败: %v", err) + } + + return +} + +// NewChatModel 只处理逻辑,不复用创建模型 +func NewChatModel(ctx context.Context, content string, docs []*schema.Document) (replyMsg *schema.Message, sources []string, err error) { + // 1. 构建参考知识 + knowledge, sources := buildKnowledgeAndSources(docs) + + // 2. 构建提示词 + msgs, err := buildPromptMessages(ctx, knowledge, content) + if err != nil { + return + } + + // 3. 🔥 直接使用全局单例,不重复创建 + replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs) + + return +} + +// buildKnowledgeAndSources 拼接参考知识 + 提取文档来源 +func buildKnowledgeAndSources(docs []*schema.Document) (string, []string) { + var knowledge string + var sources []string + + for i, doc := range docs { + knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content) + + // 提取 document_id + if docID, ok := doc.MetaData["document_id"].(int64); ok && docID > 0 { + sources = append(sources, gconv.String(docID)) + } + } + return knowledge, sources +} + +// buildPromptMessages 构建提示词模板 +func buildPromptMessages(ctx context.Context, knowledge string, question string) (msgs []*schema.Message, err error) { + promptTpl := prompt.FromMessages( + schema.FString, + &schema.Message{ + Role: schema.System, + // Content: `你是专业的客服助手,语气友好。 + //如果参考知识中有相关信息,请优先依据参考知识回答。 + //如果没有相关信息,就正常回答,不要说无法回答。 + // + //参考知识: + //{knowledge}`, + Content: `你是专业的客服助手,语气友好。 + 请根据参考知识回答用户问题,无法回答则说:抱歉,我暂时无法回答这个问题。 + + 参考知识: + {knowledge}`, + }, + &schema.Message{ + Role: schema.User, + Content: "{question}", + }, + ) + + return promptTpl.Format(ctx, map[string]any{ + "knowledge": knowledge, + "question": question, + }) +} + +// streamGenerateAnswer 流式生成 +func streamGenerateAnswer(ctx context.Context, chatModel *qwen.ChatModel, msgs []*schema.Message) (reply *schema.Message, err error) { + + sr, err := chatModel.Stream(ctx, msgs) + if err != nil { + return nil, fmt.Errorf("stream failed: %w", err) + } + + var chunks []*schema.Message + for { + chunk, err := sr.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, fmt.Errorf("stream recv failed: %w", err) + } + chunks = append(chunks, chunk) + } + + return schema.ConcatMessages(chunks) +} diff --git a/common/eino/embedding_qwen.go b/common/eino/embedding_qwen.go deleted file mode 100644 index 9496874..0000000 --- a/common/eino/embedding_qwen.go +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Copyright 2024 Red Future Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package eino - -import ( - "context" - "fmt" - "net/http" - "time" - - "github.com/cloudwego/eino/callbacks" - "github.com/cloudwego/eino/components" - "github.com/cloudwego/eino/components/embedding" - "github.com/gogf/gf/v2/frame/g" - "github.com/gogf/gf/v2/net/gclient" - "github.com/gogf/gf/v2/util/gconv" -) - -var ( - // 千问API默认配置 - defaultBaseURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding" - defaultTimeout = 10 * time.Minute - defaultRetryTimes = 2 -) - -type QwenEmbeddingConfig struct { - // Timeout specifies the maximum duration to wait for API responses - // Optional. Default: 10 minutes - Timeout *time.Duration `json:"timeout"` - - // HTTPClient specifies the client to send HTTP requests. - // Optional. Default &http.Client{Timeout: Timeout} - HTTPClient *http.Client `json:"http_client"` - - // RetryTimes specifies the number of retry attempts for failed API calls - // Optional. Default: 2 - RetryTimes *int `json:"retry_times"` - - // BaseURL specifies the base URL for Qwen DashScope service - // Optional. Default: "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding" - BaseURL string `json:"base_url"` - - // APIKey specifies the API Key for authentication - // Required - APIKey string `json:"api_key"` - - // Model specifies the model name for Qwen embedding - // Required. Examples: "text-embedding-v2", "text-embedding-v3" - Model string `json:"model"` - - // TextType specifies the type of text: "document" or "query" - // Optional. Default: "document" - TextType string `json:"text_type"` - - // MaxConcurrentRequests specifies the maximum number of concurrent requests allowed - // Optional. Default: 5 - MaxConcurrentRequests *int `json:"max_concurrent_requests"` -} - -type QwenEmbedder struct { - client *gclient.Client - conf *QwenEmbeddingConfig -} - -// EmbeddingRequest 千问embedding请求结构 -type EmbeddingRequest struct { - Model string `json:"model"` - Input struct { - Texts []string `json:"texts"` - } `json:"input"` - Parameters struct { - TextType string `json:"text_type,omitempty"` - } `json:"parameters,omitempty"` -} - -// EmbeddingResponse 千问embedding响应结构 -type EmbeddingResponse struct { - Output struct { - Embeddings []struct { - TextIndex int `json:"text_index"` - Embedding []float64 `json:"embedding"` - } `json:"embeddings"` - } `json:"output"` - Usage struct { - TotalTokens int `json:"total_tokens"` - } `json:"usage"` - RequestID string `json:"request_id"` -} - -type APIError struct { - Code string `json:"code"` - Message string `json:"message"` - RequestID string `json:"request_id"` -} - -func (e *APIError) Error() string { - return fmt.Sprintf("API Error: %s - %s (RequestID: %s)", e.Code, e.Message, e.RequestID) -} - -func buildQwenClient(config *QwenEmbeddingConfig) *gclient.Client { - if len(config.BaseURL) == 0 { - config.BaseURL = defaultBaseURL - } - if config.Timeout == nil { - config.Timeout = &defaultTimeout - } - if config.RetryTimes == nil { - defaultRetryTimes := 2 - config.RetryTimes = &defaultRetryTimes - } - if len(config.TextType) == 0 { - config.TextType = "document" - } - if config.MaxConcurrentRequests == nil { - defaultMaxConcurrentRequests := 5 - config.MaxConcurrentRequests = &defaultMaxConcurrentRequests - } - - client := g.Client() - client.SetTimeout(*config.Timeout) - - return client -} - -func NewQwenEmbedder(ctx context.Context, config *QwenEmbeddingConfig) (*QwenEmbedder, error) { - if len(config.APIKey) == 0 { - return nil, fmt.Errorf("[Qwen] APIKey is required") - } - if len(config.Model) == 0 { - return nil, fmt.Errorf("[Qwen] Model is required") - } - - client := buildQwenClient(config) - - return &QwenEmbedder{ - client: client, - conf: config, - }, nil -} - -func (e *QwenEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ( - [][]float64, error) { - - if len(texts) == 0 { - return nil, fmt.Errorf("[Qwen] texts cannot be empty") - } - - options := embedding.GetCommonOptions(&embedding.Options{ - Model: &e.conf.Model, - }, opts...) - - conf := &embedding.Config{ - Model: dereferenceOrZero(options.Model), - } - - ctx = callbacks.EnsureRunInfo(ctx, e.GetType(), components.ComponentOfEmbedding) - ctx = callbacks.OnStart(ctx, &embedding.CallbackInput{ - Texts: texts, - Config: conf, - }) - defer func() { - if err := recover(); err != nil { - callbacks.OnError(ctx, fmt.Errorf("[Qwen] panic: %v", err)) - } - }() - - var usage *embedding.TokenUsage - var embeddings [][]float64 - var err error - - // 调用千问API获取embedding - embeddings, usage, err = e.callEmbeddingAPI(ctx, texts) - if err != nil { - callbacks.OnError(ctx, err) - return nil, err - } - - callbacks.OnEnd(ctx, &embedding.CallbackOutput{ - Embeddings: embeddings, - Config: conf, - TokenUsage: usage, - }) - - return embeddings, nil -} - -func (e *QwenEmbedder) callEmbeddingAPI(ctx context.Context, texts []string) ([][]float64, *embedding.TokenUsage, error) { - // 构建请求 - var req EmbeddingRequest - req.Model = e.conf.Model - req.Input.Texts = texts - req.Parameters.TextType = e.conf.TextType - - // 调用API - client := e.client.Clone() - client.SetHeader("Authorization", "Bearer "+e.conf.APIKey) - client.SetHeader("Content-Type", "application/json") - client.SetTimeout(*e.conf.Timeout) - - resp, err := client.Post(ctx, e.conf.BaseURL, req) - if err != nil { - return nil, nil, fmt.Errorf("[Qwen] HTTP request error: %w", err) - } - - defer resp.Close() - - // 检查状态码 - if resp.StatusCode != http.StatusOK { - var errResp APIError - result := resp.ReadAll() - if err = gconv.Struct(result, &errResp); err == nil && errResp.Code != "" { - return nil, nil, &errResp - } - return nil, nil, fmt.Errorf("[Qwen] HTTP status error: %d", resp.StatusCode) - } - - // 解析响应 - var apiResp EmbeddingResponse - result := resp.ReadAll() - if err = gconv.Struct(result, &apiResp); err != nil { - return nil, nil, fmt.Errorf("[Qwen] parse response error: %w", err) - } - - // 解析响应结果 - embeddings := make([][]float64, len(texts)) - for _, emb := range apiResp.Output.Embeddings { - if emb.TextIndex >= 0 && emb.TextIndex < len(embeddings) { - embeddings[emb.TextIndex] = emb.Embedding - } - } - - usage := &embedding.TokenUsage{ - TotalTokens: apiResp.Usage.TotalTokens, - } - - g.Log().Debugf(ctx, "[Qwen] Embedding success: request_id=%s, total_tokens=%d", apiResp.RequestID, usage.TotalTokens) - - return embeddings, usage, nil -} - -func (e *QwenEmbedder) GetType() string { - return getType() -} - -func (e *QwenEmbedder) IsCallbacksEnabled() bool { - return true -} - -func getType() string { - return "Qwen" -} - -func dereferenceOrZero[T any](v *T) T { - if v == nil { - var t T - return t - } - return *v -} diff --git a/common/eino/priority_enum.go b/common/eino/priority_enum.go deleted file mode 100644 index 371706b..0000000 --- a/common/eino/priority_enum.go +++ /dev/null @@ -1,11 +0,0 @@ -package eino - -// TaskPriority 任务优先级 -type TaskPriority string - -const ( - TaskPriorityLow TaskPriority = "low" // 低优先级 - TaskPriorityMedium TaskPriority = "medium" // 中优先级 - TaskPriorityHigh TaskPriority = "high" // 高优先级 - TaskPriorityUrgent TaskPriority = "urgent" // 紧急 -) diff --git a/common/eino/retriever.go b/common/eino/retriever.go index a74e313..84b1505 100644 --- a/common/eino/retriever.go +++ b/common/eino/retriever.go @@ -3,6 +3,8 @@ package eino import ( "context" "errors" + "rag/dao" + "sort" "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/embedding" @@ -16,12 +18,14 @@ type PGVectorRetrieverConfig struct { Embedder embedding.Embedder DefaultTopK int DefaultIndex string + DSLInfo map[string]any } type PGVectorRetriever struct { embedder embedding.Embedder topK int index string + dslInfo map[string]any } func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) { @@ -36,43 +40,62 @@ func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, embedder: config.Embedder, topK: config.DefaultTopK, index: config.DefaultIndex, + dslInfo: config.DSLInfo, }, nil } func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) { - - // 1. 处理公共 Option(官方标准写法) options := &retriever.Options{ Index: &r.index, TopK: &r.topK, + DSLInfo: r.dslInfo, Embedding: r.embedder, } options = retriever.GetCommonOptions(options, opts...) - // 2. 回调(官方标准) ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{ Query: query, TopK: *options.TopK, }) - // 3. 执行检索 - docs, err := r.doRetrieve(ctx, query, options) + // ========================================== + // 🔥 双路检索:向量 + 全文 + // ========================================== + docsVector, err := r.doRetrieveVector(ctx, query, options) if err != nil { callbacks.OnError(ctx, err) return nil, err } - // 4. 完成回调 - callbacks.OnEnd(ctx, &retriever.CallbackOutput{ - Docs: docs, + docsFulltext, err := r.doRetrieveMeilisearch(ctx, query, options) + if err != nil { + callbacks.OnError(ctx, err) + return nil, err + } + + // 合并 + 去重 + docs := mergeAndDeduplicate(docsVector, docsFulltext) + + // 排序(distance 越小越靠前) + sort.Slice(docs, func(i, j int) bool { + d1 := gconv.Float64(docs[i].MetaData["distance"]) + d2 := gconv.Float64(docs[j].MetaData["distance"]) + return d1 < d2 }) + // 最多保留 topK + if len(docs) > *options.TopK { + docs = docs[:*options.TopK] + } + + callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs}) return docs, nil } -func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) { - - // 1. 生成向量 +// ========================================== +// 1. 向量检索(PG) +// ========================================== +func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) { vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query}) if err != nil { return nil, err @@ -81,37 +104,76 @@ func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts * return nil, errors.New("empty query vector") } - queryVec := pgvector.NewVector(vectors[0]) + queryVec := pgvector.NewVector(gconv.Float32s(vectors[0])) topK := *opts.TopK + datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"]) - // 2. PG 向量相似度检索 SQL - sql := ` - SELECT id, content, dataset_id, document_id, - vector <-> ? AS distance - FROM document_chunk - ORDER BY distance ASC - LIMIT ? -` - - // 3. 查询 - rows, err := dao.DocumentChunk.GetDB().GetAll(ctx, sql, queryVec, topK) + rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK) if err != nil { return nil, err } - // 4. 转为 Eino Document docs := make([]*schema.Document, 0, len(rows)) for _, row := range rows { docs = append(docs, &schema.Document{ ID: gconv.String(row["id"]), Content: gconv.String(row["content"]), - Metadata: map[string]any{ - "dataset_id": row["dataset_id"], - "document_id": row["document_id"], - "distance": row["distance"], + MetaData: map[string]any{ + "dataset_id": gconv.Int64(row["dataset_id"]), + "document_id": gconv.Int64(row["document_id"]), + "distance": gconv.Float64(row["distance"]), + "retrieve_by": "vector", }, }) } - return docs, nil } + +// ========================================== +// 2. 全文检索(Meilisearch)🔥 新增 +// ========================================== +func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) { + topK := *opts.TopK + datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"]) + + // 调用你已有的 Meilisearch DAO + rows, err := dao.DocumentChunk.SearchByKeywords(ctx, query, datasetIds, topK) + if err != nil { + return nil, err + } + + docs := make([]*schema.Document, 0, len(rows)) + for _, row := range rows { + docs = append(docs, &schema.Document{ + ID: gconv.String(row["id"]), + Content: gconv.String(row["content"]), + MetaData: map[string]any{ + "dataset_id": gconv.Int64(row["dataset_id"]), + "document_id": gconv.Int64(row["document_id"]), + "distance": 0.1, // 全文结果给高分 + "retrieve_by": "fulltext", + }, + }) + } + return docs, nil +} + +// ========================================== +// 合并去重 +// ========================================== +func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document { + idMap := make(map[string]*schema.Document) + for _, d := range vecDocs { + idMap[d.ID] = d + } + for _, d := range fullDocs { + if _, exists := idMap[d.ID]; !exists { + idMap[d.ID] = d + } + } + merged := make([]*schema.Document, 0, len(idMap)) + for _, d := range idMap { + merged = append(merged, d) + } + return merged +} diff --git a/common/eino/status_enum.go b/common/eino/status_enum.go deleted file mode 100644 index 6e12daf..0000000 --- a/common/eino/status_enum.go +++ /dev/null @@ -1,12 +0,0 @@ -package eino - -// TaskStatus 任务状态 -type TaskStatus string - -const ( - TaskStatusPending TaskStatus = "pending" // 待处理 - TaskStatusRunning TaskStatus = "running" // 运行中 - TaskStatusCompleted TaskStatus = "completed" // 已完成 - TaskStatusFailed TaskStatus = "failed" // 失败 - TaskStatusCancelled TaskStatus = "cancelled" // 已取消 -) diff --git a/common/eino/task_type.go b/common/eino/task_type.go deleted file mode 100644 index 0ba5a64..0000000 --- a/common/eino/task_type.go +++ /dev/null @@ -1,14 +0,0 @@ -package eino - -// TaskType 任务类型 -type TaskType string - -const ( - TaskTypeDocumentIngestion TaskType = "document_ingestion" // 文档摄入任务 - TaskTypeVectorIngestion TaskType = "vector_ingestion" // 向量摄入任务 - TaskTypeIndexCreation TaskType = "index_creation" // 索引创建任务 - TaskTypeQAProcessing TaskType = "qa_processing" // 问答处理任务 - TaskTypeKnowledgeConstruction TaskType = "knowledge_construction" // 知识库构建任务 - TaskTypeGraphBuilding TaskType = "graph_building" // 图谱构建任务 - TaskTypeKnowledgeSync TaskType = "knowledge_sync" // 知识同步任务 -) diff --git a/common/gse/utils.go b/common/gse/utils.go deleted file mode 100644 index aea4b38..0000000 --- a/common/gse/utils.go +++ /dev/null @@ -1,114 +0,0 @@ -package gse - -import ( - "context" - "sort" - - "github.com/go-ego/gse" - "github.com/go-ego/gse/hmm/extracker" - "github.com/go-ego/gse/hmm/segment" - "github.com/gogf/gf/v2/os/glog" -) - -var GseTool *gseTool - -// 初始化函数:程序启动时执行一次 -func init() { - var err error - GseTool, err = newGseTool() - if err != nil { - glog.Error(context.Background(), err) - } -} - -// gseTool 关键词提取工具(gse v1.0.2 标准) -type gseTool struct { - seg gse.Segmenter - tfidf *extracker.TagExtracter - tr *extracker.TextRanker -} - -// newGseTool 初始化工具(内置词典 + 停用词) -func newGseTool() (tool *gseTool, err error) { - // 1. 初始化分词器 - var seg gse.Segmenter - // 内置词典(无外部文件) - err = seg.LoadDictEmbed() - if err != nil { - return - } - // 内置停用词(v1.0.2 标准) - err = seg.LoadStopEmbed() - if err != nil { - return - } - - // 2. 初始化 TF-IDF 提取器 - tfidf := &extracker.TagExtracter{} - tfidf.WithGse(seg) - err = tfidf.LoadIdf() - if err != nil { - return - } - - // 3. 初始化 TextRank 提取器 - tr := &extracker.TextRanker{} - tr.WithGse(seg) - - tool = &gseTool{ - seg: seg, - tfidf: tfidf, - tr: tr, - } - return -} - -// Cut 分词(关键词提取唯一正确模式:精确模式 + HMM) -func (k *gseTool) Cut(text string) []string { - return k.seg.Cut(text, true) -} - -// Keyword 最终输出:关键词 + 权重 -type Keyword struct { - Word string `json:"word"` - Score float64 `json:"score"` -} - -func (k *gseTool) Extract(text string, topN int) []Keyword { - // 1. 提取 TF-IDF - tfTags := k.extractTFIDF(text, topN) - - // 2. 提取 TextRank - trTags := k.extractTextRank(text, topN) - - // 3. 合并成最终关键词(业务最常用) - scoreMap := make(map[string]float64) - for _, tag := range tfTags { - scoreMap[tag.Text] = tag.Weight - } - for _, tag := range trTags { - scoreMap[tag.Text] = tag.Weight - } - - // 转成切片并排序(高分在前) - res := make([]Keyword, 0, len(scoreMap)) - for word, score := range scoreMap { - res = append(res, Keyword{Word: word, Score: score}) - } - - sort.Slice(res, func(i, j int) bool { - return res[i].Score > res[j].Score - }) - - return res -} - -// ExtractTFIDF TF-IDF 关键词(带权重)90% 业务:文章标签、搜索、关键词 -func (k *gseTool) extractTFIDF(text string, topN int) segment.Segments { - return k.tfidf.ExtractTags(text, topN) -} - -// ExtractTextRank TextRank 关键词(带权重)长文本、摘要、语义理解 -func (k *gseTool) extractTextRank(text string, topN int) segment.Segments { - return k.tr.TextRank(text, topN) -} diff --git a/common/task/base_task.go b/common/task/base_task.go new file mode 100644 index 0000000..8e07d5c --- /dev/null +++ b/common/task/base_task.go @@ -0,0 +1,69 @@ +package task + +import ( + "time" + + "gitea.com/red-future/common/beans" +) + +type baseTaskCol struct { + beans.SQLBaseCol + TaskType string + Status string + Priority string + ParentTaskID string + TotalItems string + ProcessedItems string + Progress string + StartTime string + EndTime string + Duration string + SuccessCount string + FailCount string + Executor string + DocumentID string + Remark string +} + +var BaseTaskCol = baseTaskCol{ + SQLBaseCol: beans.DefSQLBaseCol, + TaskType: "task_type", + Status: "status", + Priority: "task_priority", + ParentTaskID: "parent_task_id", + TotalItems: "total_items", + ProcessedItems: "processed_items", + Progress: "progress", + StartTime: "start_time", + EndTime: "end_time", + Duration: "duration", + SuccessCount: "success_count", + FailCount: "fail_count", + Executor: "executor", + DocumentID: "document_id", + Remark: "remark", +} + +// SQLBaseTask 任务基类 - SQL版本 +type SQLBaseTask struct { + beans.SQLBaseDO `orm:",inline"` + // 任务核心信息 + TaskType TaskType `orm:"task_type" json:"taskType" dc:"任务类型"` + Status TaskStatus `orm:"status" json:"status" dc:"任务状态"` + Priority TaskPriority `orm:"task_priority" json:"priority,omitempty" dc:"任务优先级"` + ParentTaskID int64 `orm:"parent_task_id" json:"parentTaskId,omitempty" dc:"父任务ID"` + // 任务进度 + TotalItems int64 `orm:"total_items" json:"totalItems" dc:"总数"` + ProcessedItems int64 `orm:"processed_items" json:"processedItems" dc:"已处理数"` + Progress float64 `orm:"progress" json:"progress" dc:"进度"` // 0~100 百分比 + // 任务结果 + StartTime *time.Time `orm:"start_time" json:"startTime" dc:"开始时间"` + EndTime *time.Time `orm:"end_time" json:"endTime,omitempty" dc:"结束时间"` + Duration int64 `orm:"duration" json:"duration,omitempty" dc:"耗时(毫秒)"` + SuccessCount int64 `orm:"success_count" json:"successCount" dc:"成功数"` + FailCount int64 `orm:"fail_count" json:"failCount" dc:"失败数"` + // 其他 + Executor string `orm:"executor" json:"executor,omitempty" dc:"执行器标识"` + DocumentID int64 `orm:"document_id" json:"documentId,omitempty" dc:"文档ID"` + Remark string `orm:"remark" json:"remark,omitempty" dc:"备注/错误信息"` +} diff --git a/common/task/consts.go b/common/task/consts.go new file mode 100644 index 0000000..9dc07d1 --- /dev/null +++ b/common/task/consts.go @@ -0,0 +1,30 @@ +package task + +// TaskType 任务类型枚举:文档解析的三个子任务 +type TaskType string + +const ( + TaskTypeExtractKeywords TaskType = "EXTRACT_KEYWORDS" // 提取关键词 + TaskTypeGenerateVector TaskType = "GENERATE_VECTOR" // 生成向量 + TaskTypeFullTextSearch TaskType = "FULL_TEXT_SEARCH" // 全文检索 + TaskTypeDocParse TaskType = "DOC_PARSE" // 顶层文档解析总任务 +) + +// TaskStatus 任务状态枚举 +type TaskStatus string + +const ( + TaskStatusPending TaskStatus = "PENDING" // 待执行 + TaskStatusRunning TaskStatus = "RUNNING" // 执行中 + TaskStatusCompleted TaskStatus = "COMPLETED" // 已完成 + TaskStatusFailed TaskStatus = "FAILED" // 执行失败 +) + +// TaskPriority 任务优先级 +type TaskPriority int + +const ( + TaskPriorityLow TaskPriority = 1 // 低 + TaskPriorityMedium TaskPriority = 2 // 中 + TaskPriorityHigh TaskPriority = 3 // 高 +) diff --git a/config.yml b/config.yml index 93fddcb..f15927b 100644 --- a/config.yml +++ b/config.yml @@ -48,10 +48,10 @@ database: timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效 rag_knowledge: - type: "pgsql" - host: "localhost" - port: "5432" + host: "116.204.74.41" + port: "15432" user: "postgres" - pass: "123456" + pass: "Bjang09@686^*^" name: "tenant-1" prefix: "rag_knowledge_" # (可选)表名前缀 role: "master" @@ -69,10 +69,10 @@ database: timeMaintainDisabled: false # (可选)是否完全关闭时间更新特性,为true时CreatedAt/UpdatedAt/DeletedAt都将失效 rag_vector: - type: "pgsql" - host: "localhost" - port: "5432" + host: "116.204.74.41" + port: "15432" user: "postgres" - pass: "123456" + pass: "Bjang09@686^*^" name: "tenant-1" prefix: "rag_vector_" # (可选)表名前缀 role: "master" @@ -91,14 +91,14 @@ database: redis: default: - address: "localhost:6379" + address: "116.204.74.41:6379" db: 0 consul: - address: localhost:8500 + address: 116.204.74.41:8500 jaeger: - addr: localhost:4318 + addr: 116.204.74.41:4318 # eino框架配置 eino: @@ -115,6 +115,10 @@ eino: # apiType: "multi_modal_api" apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9" model: "text-embedding-v3" + chatmodel: + provider: "dashscope" + apiKey: "sk-4a8b82770bf74bc490eb3e4c5a8e2be9" + model: "qwen-turbo" # 文件上传服务地址,与oss模块minio中的endpoint一致 filePrefix: "http://116.204.74.41:9000" @@ -122,7 +126,7 @@ filePrefix: "http://116.204.74.41:9000" gmq: redis: primary: - addr: "localhost" + addr: "116.204.74.41" port: "6379" db: 0 username: "" diff --git a/consts/public/table_name.go b/consts/public/table_name.go index e9a6ef3..e31d1e0 100644 --- a/consts/public/table_name.go +++ b/consts/public/table_name.go @@ -1,5 +1,6 @@ package public +// 数据库名称 const ( DbNameKnowledge = "rag_knowledge" DbNameVector = "rag_vector" @@ -10,6 +11,7 @@ const ( TableNameDocument = "document" TableNameDataset = "dataset" TableNameKeyword = "keyword" + TableNameTask = "task" TableNameDatasetIndex = "dataset_index" TableNameDocumentChunk = "document_chunk" ) diff --git a/controller/document.go b/controller/document.go index 0017ddc..e5bb376 100644 --- a/controller/document.go +++ b/controller/document.go @@ -48,7 +48,7 @@ func (c *document) List(ctx context.Context, req *dto.ListDocumentReq) (res *dto } // Process 处理文件(向量化) -func (c *document) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) { - res, err = service.Document.Process(ctx, req) +func (c *document) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *beans.ResponseEmpty, err error) { + err = service.Document.Process(ctx, req) return } diff --git a/controller/rag_query.go b/controller/rag_query.go new file mode 100644 index 0000000..a480566 --- /dev/null +++ b/controller/rag_query.go @@ -0,0 +1,17 @@ +package controller + +import ( + "context" + "rag/model/dto" + "rag/service" +) + +type ragQuery struct{} + +var RAGQuery = new(ragQuery) + +// Query 执行RAG查询 +func (c *ragQuery) Query(ctx context.Context, req *dto.RAGQueryReq) (res *dto.RAGQueryRes, err error) { + res, err = service.RAGQuery.Query(ctx, req) + return +} diff --git a/dao/dataset_index.go b/dao/dataset_index.go index fa9c078..96f336c 100644 --- a/dao/dataset_index.go +++ b/dao/dataset_index.go @@ -49,6 +49,7 @@ func (d *datasetIndexDao) InsertIndex(ctx context.Context, indexName string) (er CREATE INDEX IF NOT EXISTS %s ON %s USING ivfflat (vector vector_cosine_ops) + WITH (lists = 100) WHERE vector IS NOT NULL; `, indexName, gfdb.TablePrefix+public.TableNameDocumentChunk) _, err = db.Exec(ctx, sqlStr) diff --git a/dao/document_chunk.go b/dao/document_chunk.go index c39501f..218cf70 100644 --- a/dao/document_chunk.go +++ b/dao/document_chunk.go @@ -2,12 +2,17 @@ package dao import ( "context" + "fmt" "rag/consts/public" "rag/model/dto" "rag/model/entity" "gitea.com/red-future/common/db/gfdb" + "gitea.com/red-future/common/full-text-search/meilisearch" + "github.com/gogf/gf/v2/database/gdb" + "github.com/gogf/gf/v2/text/gstr" "github.com/gogf/gf/v2/util/gconv" + "github.com/pgvector/pgvector-go" ) var DocumentChunk = new(documentChunkDao) @@ -55,3 +60,56 @@ func (d *documentChunkDao) List(ctx context.Context, req *dto.ListDocumentChunkR err = r.Structs(&res) return } + +func (d *documentChunkDao) GetAllByVector(ctx context.Context, datasetId []int64, queryVec pgvector.Vector, topK int) (list gdb.List, err error) { + sql := ` + SELECT id, content, dataset_id, document_id, + vector <-> ? AS distance + FROM rag_vector_document_chunk + WHERE dataset_id IN (?) + AND vector IS NOT NULL + ORDER BY distance ASC + LIMIT ? +` + // 顺序:vector, dataset_id, topK + result, err := gfdb.DB(ctx, public.DbNameVector).GetAll(ctx, sql, queryVec, datasetId, topK) + if err != nil { + return nil, err + } + + return result.List(), nil +} + +// SearchByKeywords 通过关键词全文检索文档块 +func (d *documentChunkDao) SearchByKeywords(ctx context.Context, query string, datasetIds []int64, topK int) (list gdb.List, err error) { + // 构建 meilisearch 查询参数 + searchParams := &meilisearch.SearchParams{ + Query: query, + Limit: int64(topK), + } + + // 构建 datasetIds 过滤条件 + if len(datasetIds) > 0 { + datasetIdStrs := gconv.Strings(datasetIds) + quotedIds := make([]string, len(datasetIdStrs)) + for i, id := range datasetIdStrs { + quotedIds[i] = fmt.Sprintf("%s", id) + } + searchParams.Filter = fmt.Sprintf("dataset_id IN [%s]", gstr.Implode(", ", quotedIds)) + } + + // 执行搜索 + var hits []map[string]interface{} + _, err = meilisearch.DB().Search(ctx, searchParams, public.IndexNameDocumentChunk, &hits) + if err != nil { + return nil, err + } + + // 转换查询结果为 gdb.List + resultList := make(gdb.List, 0, len(hits)) + for _, hit := range hits { + resultList = append(resultList, hit) + } + + return resultList, nil +} diff --git a/dao/keyword.go b/dao/keyword.go index 156648c..a0cc8fe 100644 --- a/dao/keyword.go +++ b/dao/keyword.go @@ -82,6 +82,9 @@ func (d *keywordDao) List(ctx context.Context, req *dto.ListKeywordReq, fields . if !g.IsEmpty(req.Keyword) { model.WhereLike(entity.KeywordCol.Word, "%"+req.Keyword+"%") } + model.WhereIn(entity.KeywordCol.Word, req.Words) + model.Where(entity.KeywordCol.DatasetId, req.DatasetId) + model.Where(entity.KeywordCol.DocumentId, req.DocumentId) model.OrderDesc(entity.KeywordCol.Weight) model.OrderDesc(entity.KeywordCol.CreatedAt) if req.Page != nil { diff --git a/dao/task.go b/dao/task.go new file mode 100644 index 0000000..acf08c1 --- /dev/null +++ b/dao/task.go @@ -0,0 +1,58 @@ +package dao + +import ( + "context" + "rag/consts/public" + "rag/model/dto" + "rag/model/entity" + + "gitea.com/red-future/common/db/gfdb" + "github.com/gogf/gf/v2/util/gconv" +) + +var Task = new(taskDao) + +type taskDao struct{} + +// Insert 创建任务 +func (d *taskDao) Insert(ctx context.Context, req *dto.CreateTaskReq) (id int64, err error) { + var res *entity.Task + if err = gconv.Struct(req, &res); err != nil { + return + } + r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).Data(&res).Insert() + if err != nil { + return + } + return r.LastInsertId() +} + +// Update 更新任务 +func (d *taskDao) Update(ctx context.Context, req *dto.UpdateTaskReq) (rows int64, err error) { + model := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask) + r, err := model.Data(&req).Where(entity.TaskCol.Id, req.Id).Where(entity.TaskCol.TaskId, req.TaskId).OmitEmpty().Update() + if err != nil { + return + } + return r.RowsAffected() +} + +func (d *taskDao) Get(ctx context.Context, req *dto.GetTaskReq) (res []*entity.Task, total int, err error) { + r, total, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).OmitEmpty(). + Where(entity.TaskCol.Id, req.Id). + Where(entity.TaskCol.TaskId, req.TaskId). + Where(entity.TaskCol.TaskType, req.TaskType).AllAndCount(false) + if err != nil { + return + } + err = r.Structs(&res) + return +} + +func (d *taskDao) DeleteByTaskId(ctx context.Context, req *dto.DeleteTaskByTaskIdReq) (rows int64, err error) { + r, err := gfdb.DB(ctx, public.DbNameKnowledge).Model(ctx, public.TableNameTask).Where(entity.TaskCol.TaskId, req.TaskId).Delete() + if err != nil { + return + } + return r.RowsAffected() +} diff --git a/go.mod b/go.mod index dd457bd..7a29f26 100644 --- a/go.mod +++ b/go.mod @@ -16,9 +16,9 @@ require ( github.com/cloudwego/eino-ext/components/embedding/dashscope v0.0.0-20260323112355-f061db7e8419 github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419 github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9 + github.com/cloudwego/eino-ext/components/model/qwen v0.1.7 github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9 github.com/elastic/go-elasticsearch/v8 v8.16.0 - github.com/go-ego/gse v1.0.2 github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.0 github.com/gogf/gf/v2 v2.10.0 github.com/golang/glog v1.2.5 @@ -50,7 +50,7 @@ require ( github.com/clipperhouse/uax29/v2 v2.7.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/cloudwego/eino-ext/components/document/parser/html v0.0.0-20241224063832-9fbcc0e56c28 // indirect - github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14 // indirect + github.com/cloudwego/eino-ext/libs/acl/openai v0.1.15 // indirect github.com/dgraph-io/badger/v4 v4.2.0 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect @@ -64,6 +64,7 @@ require ( github.com/fatih/color v1.19.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.13 // indirect + github.com/go-ego/gse v1.0.2 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect @@ -105,7 +106,7 @@ require ( github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.21 // indirect - github.com/meguminnnnnnnnn/go-openai v0.1.1 // indirect + github.com/meguminnnnnnnnn/go-openai v0.1.2 // indirect github.com/meilisearch/meilisearch-go v0.36.1 // indirect github.com/microcosm-cc/bluemonday v1.0.27 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect @@ -134,7 +135,7 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/vcaesar/cedar v0.30.0 // indirect github.com/volcengine/volc-sdk-golang v1.0.199 // indirect - github.com/volcengine/volcengine-go-sdk v1.0.181 // indirect + github.com/volcengine/volcengine-go-sdk v1.2.9 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/xuri/efp v0.0.0-20240408161823-9ad904a10d6d // indirect github.com/xuri/excelize/v2 v2.9.0 // indirect diff --git a/go.sum b/go.sum index d31f601..4d8bdf9 100644 --- a/go.sum +++ b/go.sum @@ -33,8 +33,6 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9 dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= entgo.io/ent v0.14.3 h1:wokAV/kIlH9TeklJWGGS7AYJdVckr0DloWjIcO9iIIQ= entgo.io/ent v0.14.3/go.mod h1:aDPE/OziPEu8+OWbzy4UlvWmD2/kbRuWfK2A40hcxJM= -gitea.com/red-future/common v0.0.11 h1:AV7W3G0uZ8aPpHHSHd4ZHmLWe5+2STPKe/AYPoPCWVc= -gitea.com/red-future/common v0.0.11/go.mod h1:B8syUI4XbLCDQSeRHURYxEwnWw8mEFgmqCxjC+lM+NU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/BurntSushi/toml v1.6.0 h1:dRaEfpa2VI55EwlIW72hMRHdWouJeRF7TPYhI+AUQjk= @@ -158,10 +156,12 @@ github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355- github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260323112355-f061db7e8419/go.mod h1:SajSFFRIXJXIbxadAAlSUIS5KTY8R/jzJg9RNSOXCCI= github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9 h1:vZ3dL8xwo2sy73aBVKs4AJiO5OCHRxMOJUwIYkp0CWs= github.com/cloudwego/eino-ext/components/indexer/es8 v0.0.0-20260331071634-4f359694d2d9/go.mod h1:+oI0sr0rA0OHCxaQJ0rzMYld3LAODHhPKzBx5JYCya0= +github.com/cloudwego/eino-ext/components/model/qwen v0.1.7 h1:8c1LB5lH+dERbf2twp18B1Y822JOQSsS6x7Vnksehk0= +github.com/cloudwego/eino-ext/components/model/qwen v0.1.7/go.mod h1:n4iuIUQeL3D8GRsGAhkgceRZpoyPQbqOXFMXM2Q4hNY= github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9 h1:Sl6giB1SJlA+ZlO0gzPH05IsUORtdYYPN6GiyH1B9MA= github.com/cloudwego/eino-ext/components/retriever/es8 v0.0.0-20260331071634-4f359694d2d9/go.mod h1:H4kNmiTe2irnvipVNIP4q8yqXf2fZ6v24krvQYBtYb8= -github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14 h1:yOZII6VYaL00CVZYba+HUixFygsW0Xz/1QjQ5htj1Ls= -github.com/cloudwego/eino-ext/libs/acl/openai v0.1.14/go.mod h1:1xMQZ8eE11pkEoTAEy8UlaAY817qGVMvjpDPGSIO3Ns= +github.com/cloudwego/eino-ext/libs/acl/openai v0.1.15 h1:LbdSG9+qWzzp9RFW6dSFkaUW171JvCoYn/K63zX6dQE= +github.com/cloudwego/eino-ext/libs/acl/openai v0.1.15/go.mod h1:p+l0zBB0GjjX8HTlbTs3g3KfUFwZC11bsCGZOXW/3L0= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= @@ -531,8 +531,8 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D github.com/mattn/go-runewidth v0.0.21 h1:jJKAZiQH+2mIinzCJIaIG9Be1+0NR+5sz/lYEEjdM8w= github.com/mattn/go-runewidth v0.0.21/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/meguminnnnnnnnn/go-openai v0.1.1 h1:u/IMMgrj/d617Dh/8BKAwlcstD74ynOJzCtVl+y8xAs= -github.com/meguminnnnnnnnn/go-openai v0.1.1/go.mod h1:qs96ysDmxhE4BZoU45I43zcyfnaYxU3X+aRzLko/htY= +github.com/meguminnnnnnnnn/go-openai v0.1.2 h1:iXombGGjqjBrmE9WaSidUhhi3YQhf42QTHvHLMkgvCA= +github.com/meguminnnnnnnnn/go-openai v0.1.2/go.mod h1:qs96ysDmxhE4BZoU45I43zcyfnaYxU3X+aRzLko/htY= github.com/meilisearch/meilisearch-go v0.36.1 h1:mJTCJE5g7tRvaqKco6DfqOuJEjX+rRltDEnkEC02Y0M= github.com/meilisearch/meilisearch-go v0.36.1/go.mod h1:hWcR0MuWLSzHfbz9GGzIr3s9rnXLm1jqkmHkJPbUSvM= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b h1:j7+1HpAFS1zy5+Q4qx1fWh90gTKwiN4QCGoY9TWyyO4= @@ -735,8 +735,8 @@ github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU= github.com/volcengine/volc-sdk-golang v1.0.199 h1:zv9QOqTl/IsLwtfC37GlJtcz6vMAHi+pjq8ILWjLYUc= github.com/volcengine/volc-sdk-golang v1.0.199/go.mod h1:stZX+EPgv1vF4nZwOlEe8iGcriUPRBKX8zA19gXycOQ= -github.com/volcengine/volcengine-go-sdk v1.0.181 h1:/3PB4M1N4fjMqiSKTJwX43EZ5Nn1HUOtQrSCk+22+wI= -github.com/volcengine/volcengine-go-sdk v1.0.181/go.mod h1:gfEDc1s7SYaGoY+WH2dRrS3qiuDJMkwqyfXWCa7+7oA= +github.com/volcengine/volcengine-go-sdk v1.2.9 h1:du2gnImtyWXKkQFnJW/GXCs+UBibGGOXIbP1Ams2pB8= +github.com/volcengine/volcengine-go-sdk v1.2.9/go.mod h1:oxoVo+A17kvkwPkIeIHPVLjSw7EQAm+l/Vau1YGHN+A= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= @@ -1193,6 +1193,7 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= diff --git a/main.go b/main.go index b706aec..351275f 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "gitea.com/red-future/common/http" "gitea.com/red-future/common/jaeger" + "gitea.com/red-future/common/utils" gmq "github.com/bjang03/gmq/core/gmq" "github.com/bjang03/gmq/mq" "github.com/bjang03/gmq/types" @@ -27,22 +28,17 @@ func main() { controller.Dataset, controller.Document, controller.DocumentChunk, + controller.Keyword, + controller.RAGQuery, }) - gmq.Init("config.yml") - - if err := gmq.GetGmq("primary").GmqSubscribe(ctx, &mq.RedisSubMessage{ - SubMessage: types.SubMessage{ - Topic: public.KnowledgeDocumentVectorStatusTopic, - ConsumerName: public.KnowledgeDocumentVectorStatusConsumer, - AutoAck: public.KnowledgeDocumentVectorStatusAutoAck, - FetchCount: public.KnowledgeDocumentVectorStatusBatchSize, - HandleFunc: service.Document.DocsVectorStatusMsg, - }, - }); err != nil { - return + err := utils.InitGseTool(ctx) + if err != nil { + g.Log().Error(ctx, "gse 分词工具初始化失败:", err) } + gmq.Init("config.yml") + if err := gmq.GetGmq("primary").GmqSubscribe(ctx, &mq.RedisSubMessage{ SubMessage: types.SubMessage{ Topic: public.KnowledgeDocumentChunkTopic, diff --git a/model/dto/document.go b/model/dto/document.go index 1c84dab..89d0069 100644 --- a/model/dto/document.go +++ b/model/dto/document.go @@ -84,12 +84,6 @@ type ProcessDocumentReq struct { DatasetId int64 `json:"datasetId" v:"required#数据集ID不能为空"` } -// ProcessDocumentRes 处理文件响应 -type ProcessDocumentRes struct { - ChunkCount int64 `json:"chunkCount"` - CostTime int64 `json:"costTime"` -} - type ListDocumentChunkRPC struct { List []*DocumentChunkRPC `json:"list"` } diff --git a/model/dto/keyword.go b/model/dto/keyword.go index 950c504..b6ca6dc 100644 --- a/model/dto/keyword.go +++ b/model/dto/keyword.go @@ -52,6 +52,7 @@ type ListKeywordReq struct { DatasetId int64 `json:"datasetId"` DocumentId int64 `json:"documentId"` Word string `json:"word"` + Words []string `json:"words"` Keyword string `json:"keyword" dc:"关键词搜索"` } @@ -62,9 +63,11 @@ type ListKeywordRes struct { } type KeywordVO struct { - Id int64 `json:"id,string" dc:"id"` - Word string `json:"word" dc:"关键词名称"` - Weight int16 `json:"weight" dc:"权重"` - CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"` - UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"` + Id int64 `json:"id,string" dc:"id"` + Word string `json:"word" dc:"关键词名称"` + Weight int16 `json:"weight" dc:"权重"` + DatasetId int64 `json:"datasetId,string" dc:"数据集ID"` + DocumentId int64 `json:"documentId,string" dc:"文档ID"` + CreatedAt *gtime.Time `json:"createdAt" dc:"创建时间"` + UpdatedAt *gtime.Time `json:"updatedAt" dc:"更新时间"` } diff --git a/model/dto/rag_query.go b/model/dto/rag_query.go new file mode 100644 index 0000000..8682c03 --- /dev/null +++ b/model/dto/rag_query.go @@ -0,0 +1,21 @@ +package dto + +import ( + "github.com/gogf/gf/v2/frame/g" +) + +// RAGQueryReq RAG查询请求 +type RAGQueryReq struct { + g.Meta `path:"/ragQuery" method:"post" tags:"RAG查询" summary:"执行RAG查询" dc:"执行RAG查询"` + + Content string `json:"content" v:"required#查询内容不能为空" dc:"用户问题"` + DatasetIds []int64 `json:"datasetIds" dc:"数据集ID"` + TopK int `json:"topK" d:"5" dc:"检索topK,默认5"` +} + +// RAGQueryRes RAG查询响应 +type RAGQueryRes struct { + Answer string `json:"answer" dc:"生成的答案"` + DatasetId string `json:"datasetId" dc:"使用的数据集ID"` + Sources []string `json:"sources" dc:"参考来源"` +} diff --git a/model/dto/task.go b/model/dto/task.go new file mode 100644 index 0000000..9e52e4c --- /dev/null +++ b/model/dto/task.go @@ -0,0 +1,65 @@ +package dto + +import ( + "rag/common/task" +) + +// WriteTaskProgressReq 写入任务进度请求 +type WriteTaskProgressReq struct { + TaskType task.TaskType `json:"taskType" dc:"任务类型"` + Status task.TaskStatus `json:"status" dc:"任务状态"` + TaskId int64 `json:"taskId" dc:"任务ID"` + Remark string `json:"remark" dc:"备注"` +} + +// CreateTaskReq 创建任务请求 +type CreateTaskReq struct { + TaskType task.TaskType `json:"taskType" dc:"任务类型"` + Status task.TaskStatus `json:"status" dc:"任务状态"` + TaskId int64 `json:"taskId" dc:"任务ID"` + Remark string `json:"remark" dc:"备注"` +} + +// UpdateTaskReq 更新任务请求 +type UpdateTaskReq struct { + Id int64 `json:"id" dc:"任务ID"` + TaskId int64 `json:"taskId" dc:"任务ID"` + Status task.TaskStatus `json:"status" dc:"任务状态"` + Remark string `json:"remark" dc:"备注"` +} + +// DeleteTaskByTaskIdReq 删除任务请求 +type DeleteTaskByTaskIdReq struct { + TaskId int64 `json:"taskId" v:"required#任务id不能为空"` +} + +// GetTaskReq 获取任务请求 +type GetTaskReq struct { + Id int64 `json:"id" dc:"任务ID"` + TaskId int64 `json:"taskId" dc:"任务ID"` + TaskType task.TaskType `json:"taskType" dc:"任务类型"` +} + +// TaskVO 任务视图对象 +type TaskVO struct { + Id int64 `json:"id" dc:"任务ID"` + TaskType task.TaskType `json:"taskType" dc:"任务类型"` + Status task.TaskStatus `json:"status" dc:"任务状态"` + Priority task.TaskPriority `json:"priority" dc:"任务优先级"` + ParentTaskID int64 `json:"parentTaskId" dc:"父任务ID"` + TotalItems int64 `json:"totalItems" dc:"总项数"` + ProcessedItems int64 `json:"processedItems" dc:"已处理项数"` + Progress float64 `json:"progress" dc:"进度百分比"` + StartTime *int64 `json:"startTime" dc:"开始时间戳"` + EndTime *int64 `json:"endTime" dc:"结束时间戳"` + Duration int64 `json:"duration" dc:"耗时(毫秒)"` + SuccessCount int64 `json:"successCount" dc:"成功数"` + FailCount int64 `json:"failCount" dc:"失败数"` + Executor string `json:"executor" dc:"执行器"` + DocumentID int64 `json:"documentId" dc:"文档ID"` + Remark string `json:"remark" dc:"备注"` + Creator string `json:"creator" dc:"创建人"` + CreatedAt int64 `json:"createdAt" dc:"创建时间"` + Updater string `json:"updater" dc:"更新人"` + UpdatedAt int64 `json:"updatedAt" dc:"更新时间"` +} diff --git a/model/entity/task.go b/model/entity/task.go new file mode 100644 index 0000000..c06d791 --- /dev/null +++ b/model/entity/task.go @@ -0,0 +1,66 @@ +package entity + +import ( + "rag/common/task" + + "gitea.com/red-future/common/beans" +) + +type taskCol struct { + beans.SQLBaseCol + TaskId string + TaskType string + Status string + Executor string + Remark string + //Priority string + //ParentTaskId string + //TotalItems string + //ProcessedItems string + //Progress string + //StartTime string + //EndTime string + //Duration string + //SuccessCount string + //FailCount string +} + +var TaskCol = taskCol{ + SQLBaseCol: beans.DefSQLBaseCol, + TaskId: "task_id", + TaskType: "task_type", + Status: "status", + Executor: "executor", + Remark: "remark", + //Priority: "priority", + //ParentTaskId: "parent_task_id", + //TotalItems: "total_items", + //ProcessedItems: "processed_items", + //Progress: "progress", + //StartTime: "start_time", + //EndTime: "end_time", + //Duration: "duration", + //SuccessCount: "success_count", + //FailCount: "fail_count", +} + +// Task 任务记录表 +type Task struct { + beans.SQLBaseDO `orm:",inline"` + + TaskId int64 `orm:"task_id" json:"taskId" dc:"任务ID"` + TaskType task.TaskType `orm:"task_type" json:"taskType" dc:"任务类型"` + Status task.TaskStatus `orm:"status" json:"status" dc:"任务状态"` + Executor string `orm:"executor" json:"executor" dc:"执行器"` + Remark string `orm:"remark" json:"remark" dc:"备注"` + //Priority task.TaskPriority `orm:"priority" json:"priority" dc:"任务优先级"` + //ParentTaskId int64 `orm:"parent_task_id" json:"parentTaskId" dc:"父任务ID"` + //TotalItems int64 `orm:"total_items" json:"totalItems" dc:"总项数"` + //ProcessedItems int64 `orm:"processed_items" json:"processedItems" dc:"已处理项数"` + //SuccessCount int64 `orm:"success_count" json:"successCount" dc:"成功数"` + //FailCount int64 `orm:"fail_count" json:"failCount" dc:"失败数"` + //Progress float64 `orm:"progress" json:"progress" dc:"进度百分比"` + //StartTime *gtime.Time `orm:"start_time" json:"startTime" dc:"开始时间戳"` + //EndTime *gtime.Time `orm:"end_time" json:"endTime" dc:"结束时间戳"` + //Duration int64 `orm:"duration" json:"duration" dc:"耗时(毫秒)"` +} diff --git a/service/document.go b/service/document.go index 12c9964..a78448f 100644 --- a/service/document.go +++ b/service/document.go @@ -5,17 +5,14 @@ import ( "errors" "fmt" "rag/common/eino" - "rag/common/gse" + "rag/common/task" "rag/consts/document" "rag/consts/public" "rag/dao" "rag/model/dto" "rag/model/entity" "strings" - "sync" - "time" - "gitea.com/red-future/common/beans" "gitea.com/red-future/common/db/gfdb" "gitea.com/red-future/common/full-text-search/meilisearch" "gitea.com/red-future/common/http" @@ -29,6 +26,7 @@ import ( "github.com/gogf/gf/v2/database/gdb" "github.com/gogf/gf/v2/database/gredis" "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/os/grpool" "github.com/gogf/gf/v2/util/gconv" ) @@ -54,7 +52,13 @@ func (s *documentService) Create(ctx context.Context, req *dto.CreateDocumentReq return } res = &dto.CreateDocumentRes{Id: id} - + // 写入任务进度待处理 任务类型为文档解析 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: id, + TaskType: task.TaskTypeDocParse, + Status: task.TaskStatusPending, + Remark: "文档上传成功待解析: " + req.Title, + }) return }) @@ -79,11 +83,20 @@ func (s *documentService) Delete(ctx context.Context, req *dto.DeleteDocumentReq DocumentCount: -1, DocumentSize: -docs.FileSize, } - _, err = dao.Dataset.Update(ctx, datasetReq) - if err != nil { + if _, err = dao.Dataset.Update(ctx, datasetReq); err != nil { return } - _, err = dao.Document.Delete(ctx, req) + + if _, err = dao.Document.Delete(ctx, req); err != nil { + return + } + + if _, err = dao.Task.DeleteByTaskId(ctx, &dto.DeleteTaskByTaskIdReq{ + TaskId: docs.Id, + }); err != nil { + return + } + return }) @@ -107,118 +120,159 @@ func (s *documentService) List(ctx context.Context, req *dto.ListDocumentReq) (r Total: total, } err = gconv.Struct(list, &res.List) - - //eino.TestIndexer() - //eino.TestRetriever() - return } // Process 处理文件(使用eino框架切分和向量化) -func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (res *dto.ProcessDocumentRes, err error) { - startTime := time.Now() - +func (s *documentService) Process(ctx context.Context, req *dto.ProcessDocumentReq) (err error) { // 1. 查询文件信息 documentReq := dto.GetDocumentReq{Id: req.Id} doc, err := dao.Document.GetByID(ctx, &documentReq) if err != nil { - return nil, err + return err } if g.IsEmpty(doc) { - return nil, errors.New("document not found") + return errors.New("document not found") } - // 2. 使用eino框架进行文件切分(并发执行) - var vectorDocsCount, chunks int64 - // 用 gopool 或者简单的错误等待,绝对不用裸 goroutine - var err1, err2, err3 error - var wg sync.WaitGroup - wg.Add(3) - - // 任务1 - go func() { - defer wg.Done() - vectorDocsCount, chunks, err1 = s.sqlSplitDocument(ctx, doc) - }() - - // 任务2 - go func() { - defer wg.Done() - err2 = s.esSplitDocument(ctx, doc) - }() - - // 任务3 - go func() { - defer wg.Done() - err3 = s.extractDocument(ctx, doc) - }() - - // 直接等待,不使用通道,避免泄漏 - wg.Wait() - + // 2. 更新文档状态为处理中 updateDocumentReq := new(dto.UpdateDocumentReq) updateDocumentReq.Id = req.Id - - // 统一判断错误 - if err1 != nil || err2 != nil || err3 != nil { - // 更新文档状态 - updateDocumentReq.VectorStatus = document.VectorStatusFailed.Code() - if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil { - return nil, err - } - if err1 != nil { - return nil, err1 - } - if err2 != nil { - return nil, err2 - } - return nil, err3 - } - - // 4. 更新文件状态为处理中和切分数量 - if vectorDocsCount > 0 { - updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code() - } else { - updateDocumentReq.VectorStatus = document.VectorStatusCompleted.Code() - } - updateDocumentReq.ChunkCount = chunks + updateDocumentReq.VectorStatus = document.VectorStatusProcessing.Code() if _, err = dao.Document.Update(ctx, updateDocumentReq); err != nil { + // 写入任务进度失败 任务类型为文档解析 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: req.Id, + TaskType: task.TaskTypeDocParse, + Status: task.TaskStatusFailed, + Remark: "更新文档状态失败: " + err.Error(), + }) return } - - costTime := time.Since(startTime).Milliseconds() - - return &dto.ProcessDocumentRes{ - ChunkCount: chunks, - CostTime: costTime, - }, nil -} - -func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) { - // 1. 加载文件 - docs, err := s.loadDocument(ctx, doc) + // 写入任务进度进行中 任务类型为文档解析 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: req.Id, + TaskType: task.TaskTypeDocParse, + Status: task.TaskStatusRunning, + Remark: "文档解析开始", + }) if err != nil { return } - var words []gse.Keyword + // ====================== + // 核心:grpool + g.Try 最佳实践 + // ====================== + taskCtx, cancel := context.WithCancel(ctx) + + // 任务1: SQL 切分文档 + grpool.Add(taskCtx, func(ctx context.Context) { + g.TryCatch(ctx, func(ctx context.Context) { + if innerErr := s.sqlSplitDocument(ctx, doc); innerErr != nil { + cancel() + } + }, func(ctx context.Context, err error) { + cancel() + }) + }) + + // 任务2: ES 切分文档 + grpool.Add(taskCtx, func(ctx context.Context) { + g.TryCatch(ctx, func(ctx context.Context) { + if innerErr := s.esSplitDocument(ctx, doc); innerErr != nil { + cancel() + } + }, func(ctx context.Context, err error) { + cancel() + }) + }) + + // 任务3: 提取文档 + grpool.Add(taskCtx, func(ctx context.Context) { + g.TryCatch(ctx, func(ctx context.Context) { + if innerErr := s.extractDocument(ctx, doc); innerErr != nil { + cancel() + } + }, func(ctx context.Context, err error) { + cancel() + }) + }) + + return nil +} + +// extractDocument 关键词提取(支持取消) +func (s *documentService) extractDocument(ctx context.Context, doc *entity.Document) (err error) { + // ========== 取消检查 1:方法入口 ========== + if ctx.Err() != nil { + // 写入任务进度失败 任务类型为关键字存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeExtractKeywords, + Status: task.TaskStatusFailed, + Remark: "ctx取消: " + ctx.Err().Error(), + }) + return ctx.Err() + } + + // 1. 加载文件 + docs, err := s.loadDocument(ctx, doc) + if err != nil { + // 写入任务进度失败 任务类型为关键字存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeExtractKeywords, + Status: task.TaskStatusFailed, + Remark: "加载文件失败: " + err.Error(), + }) + return + } + + var words []utils.Keyword if len(docs[0].Content) < 500 { - words = gse.GseTool.Extract(docs[0].Content, 4) + words = utils.GseTool.Extract(docs[0].Content, 4) } else if len(docs[0].Content) < 2000 { - words = gse.GseTool.Extract(docs[0].Content, 8) + words = utils.GseTool.Extract(docs[0].Content, 8) } else if len(docs[0].Content) < 5000 { - words = gse.GseTool.Extract(docs[0].Content, 13) + words = utils.GseTool.Extract(docs[0].Content, 13) } else { var docsSplit []*schema.Document docsSplit, err = eino.RecursiveSplitDocument(ctx, docs) if err != nil { + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeExtractKeywords, + Status: task.TaskStatusFailed, + Remark: "递归分割文档失败: " + err.Error(), + }) return } + // ========== 取消检查 2:循环内部 ========== for _, t := range docsSplit { - words = append(words, gse.GseTool.Extract(t.Content, 6)...) + if ctx.Err() != nil { + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeExtractKeywords, + Status: task.TaskStatusFailed, + Remark: "ctx取消: " + ctx.Err().Error(), + }) + return ctx.Err() + } + words = append(words, utils.GseTool.Extract(t.Content, 6)...) } } + // ========== 取消检查 3:批量操作前 ========== + if ctx.Err() != nil { + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeExtractKeywords, + Status: task.TaskStatusFailed, + Remark: "ctx取消: " + ctx.Err().Error(), + }) + return ctx.Err() + } + var keywordReqs = make([]*dto.CreateKeywordReq, 0) for _, word := range words { keywordReqs = append(keywordReqs, &dto.CreateKeywordReq{ @@ -231,37 +285,111 @@ func (s *documentService) extractDocument(ctx context.Context, doc *entity.Docum if len(keywordReqs) > 0 { _, err = dao.Keyword.BatchSaveOrUpdate(ctx, keywordReqs) if err != nil { + // 写入任务进度失败 任务类型为关键字存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeExtractKeywords, + Status: task.TaskStatusFailed, + Remark: "关键字存储失败: " + err.Error(), + }) return } + // 写入任务进度已完成 任务类型为关键字存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeExtractKeywords, + Status: task.TaskStatusCompleted, + Remark: "关键字提取完成", + }) + } else { + // 写入任务进度已完成 任务类型为关键字存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeExtractKeywords, + Status: task.TaskStatusCompleted, + Remark: "没有提取到关键词,关键字提取完成", + }) } return } -func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (vectorDocsCount, docsSplitCount int64, err error) { +// sqlSplitDocument SQL切分(支持取消) +func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Document) (err error) { + // ========== 取消检查 1:方法入口 ========== + if ctx.Err() != nil { + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeGenerateVector, + Status: task.TaskStatusFailed, + Remark: "ctx取消: " + ctx.Err().Error(), + }) + return ctx.Err() + } + // 1. 加载文件 docs, err := s.loadDocument(ctx, doc) if err != nil { + // 写入任务进度失败 任务类型为sql存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeGenerateVector, + Status: task.TaskStatusFailed, + Remark: "加载文件失败: " + err.Error(), + }) return } + // 2. 语义切分文件 docsSplit, err := eino.SemanticSplitDocument(ctx, docs) if err != nil { + // 写入任务进度失败 任务类型为sql存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeGenerateVector, + Status: task.TaskStatusFailed, + Remark: "文档切分失败: " + err.Error(), + }) return } - docsSplitCount = gconv.Int64(len(docsSplit)) + // 2. 获取历史数据 err = s.getHistoryData(ctx, doc, public.KnowledgeLockSqlKey, public.KnowledgeContentHashSqlKey) if err != nil { + // 写入任务进度失败 任务类型为sql存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeGenerateVector, + Status: task.TaskStatusFailed, + Remark: "获取历史数据失败: " + err.Error(), + }) return } + // 3. 组装向量文档 var docsChunk = make([]*schema.Document, 0) for i, t := range docsSplit { + // ========== 取消检查 2:循环内部 ========== + if ctx.Err() != nil { + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeGenerateVector, + Status: task.TaskStatusFailed, + Remark: "ctx取消: " + ctx.Err().Error(), + }) + return ctx.Err() + } + contentHash := gmd5.MustEncryptString(t.Content) - // 检查是否重复 var success bool success, err = s.checkRepeat(ctx, public.KnowledgeContentHashSqlKey, contentHash) if err != nil { + // 写入任务进度失败 任务类型为sql存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeGenerateVector, + Status: task.TaskStatusFailed, + Remark: "检查重复数据失败: " + err.Error(), + }) return } if !success { @@ -277,6 +405,18 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu t.MetaData = metaData docsChunk = append(docsChunk, t) } + + // ========== 取消检查 3:批量发送前 ========== + if ctx.Err() != nil { + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeGenerateVector, + Status: task.TaskStatusFailed, + Remark: "ctx取消: " + ctx.Err().Error(), + }) + return ctx.Err() + } + // 4. 发送消息到队列 if len(docsChunk) > 0 { err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{ @@ -285,41 +425,117 @@ func (s *documentService) sqlSplitDocument(ctx context.Context, doc *entity.Docu Data: docsChunk, }, }) + if err != nil { + // 写入任务进度失败 任务类型为sql存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeGenerateVector, + Status: task.TaskStatusFailed, + Remark: "发送消息到队列失败: " + err.Error(), + }) + return + } + // 写入任务进度进行中 任务类型为sql存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeGenerateVector, + Status: task.TaskStatusRunning, + Remark: "向量生成任务已提交到队列", + }) + } else { + // 写入任务进度已完成 任务类型为sql存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeGenerateVector, + Status: task.TaskStatusCompleted, + Remark: "无需生成向量,任务完成", + }) } - vectorDocsCount = gconv.Int64(len(docsChunk)) return } +// esSplitDocument ES切分(支持取消) func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Document) (err error) { + // ========== 取消检查 1:方法入口 ========== + if ctx.Err() != nil { + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeFullTextSearch, + Status: task.TaskStatusFailed, + Remark: "ctx取消: " + ctx.Err().Error(), + }) + return ctx.Err() + } + // 1. 加载文件 docs, err := s.loadDocument(ctx, doc) if err != nil { + // 写入任务进度失败 任务类型为es存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeFullTextSearch, + Status: task.TaskStatusFailed, + Remark: "加载文件失败: " + err.Error(), + }) return } + // 2. 递归切分文件 docsSplit, err := eino.RecursiveSplitDocument(ctx, docs) if err != nil { + // 写入任务进度失败 任务类型为es存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeFullTextSearch, + Status: task.TaskStatusFailed, + Remark: "文档切分失败: " + err.Error(), + }) return } + // 2. 获取历史数据 err = s.getHistoryData(ctx, doc, public.KnowledgeLockEsKey, public.KnowledgeContentHashEsKey) if err != nil { + // 写入任务进度失败 任务类型为es存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeFullTextSearch, + Status: task.TaskStatusFailed, + Remark: "获取历史数据失败: " + err.Error(), + }) return } + // 3. 组装向量文档并同时构建meilisearch文档 var meiliDocs = make([]interface{}, 0) for i, t := range docsSplit { + // ========== 取消检查 2:循环内部 ========== + if ctx.Err() != nil { + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeFullTextSearch, + Status: task.TaskStatusFailed, + Remark: "ctx取消: " + ctx.Err().Error(), + }) + return ctx.Err() + } + contentHash := gmd5.MustEncryptString(t.Content) - // 检查是否重复 var success bool success, err = s.checkRepeat(ctx, public.KnowledgeContentHashEsKey, contentHash) if err != nil { + // 写入任务进度失败 任务类型为es存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeFullTextSearch, + Status: task.TaskStatusFailed, + Remark: "检查重复数据失败: " + err.Error(), + }) return } if !success { continue } - // 构建Meilisearch文档 meiliDocs = append(meiliDocs, map[string]interface{}{ entity.DocumentChunkCol.Id: contentHash, entity.DocumentChunkCol.DatasetId: doc.DatasetId, @@ -329,12 +545,45 @@ func (s *documentService) esSplitDocument(ctx context.Context, doc *entity.Docum entity.DocumentChunkCol.ChunkIndex: i, }) } + + // ========== 取消检查 3:批量写入前 ========== + if ctx.Err() != nil { + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeFullTextSearch, + Status: task.TaskStatusFailed, + Remark: "ctx取消: " + ctx.Err().Error(), + }) + return ctx.Err() + } + // 4. 写入到meilisearch数据库中 if len(meiliDocs) > 0 { if _, err = meilisearch.DB().InsertMany(ctx, meiliDocs, public.IndexNameDocumentChunk); err != nil { - g.Log().Errorf(ctx, "写入meilisearch失败: %v", err) + // 写入任务进度失败 任务类型为meilisearch存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeFullTextSearch, + Status: task.TaskStatusFailed, + Remark: "写入meilisearch失败: " + err.Error(), + }) return } + // 写入任务进度已完成 任务类型为meilisearch存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeFullTextSearch, + Status: task.TaskStatusCompleted, + Remark: "全文检索数据写入完成", + }) + } else { + // 写入任务进度已完成 任务类型为meilisearch存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: doc.Id, + TaskType: task.TaskTypeFullTextSearch, + Status: task.TaskStatusCompleted, + Remark: "无需生成全文检索数据,任务完成", + }) } return } @@ -467,20 +716,3 @@ func (s *documentService) checkRepeat(ctx context.Context, contentKey, contentHa success = val.Bool() return } - -func (s *documentService) DocsVectorStatusMsg(ctx context.Context, msg any) (err error) { - var req = new(dto.KnowledgeDocumentMsg) - if err = gconv.Struct(msg, &req); err != nil { - g.Log().Error(ctx, "DocsVectorStatusMsg err:", err) - return - } - ctx = context.WithValue(ctx, "user", &beans.User{ - TenantId: req.TenantId, - UserName: req.Creator, - }) - _, err = dao.Document.Update(ctx, &dto.UpdateDocumentReq{ - Id: req.Id, - VectorStatus: req.VectorStatus, - }) - return -} diff --git a/service/document_chunk.go b/service/document_chunk.go index d60303a..4dd3c11 100644 --- a/service/document_chunk.go +++ b/service/document_chunk.go @@ -3,15 +3,11 @@ package service import ( "context" "rag/common/eino" - "rag/consts/document" - "rag/consts/public" + "rag/common/task" "rag/dao" "rag/model/dto" "rag/model/entity" - gmq "github.com/bjang03/gmq/core/gmq" - "github.com/bjang03/gmq/mq" - "github.com/bjang03/gmq/types" "github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/schema" "github.com/gogf/gf/v2/frame/g" @@ -22,10 +18,6 @@ var DocumentChunk = new(documentChunkService) type documentChunkService struct{} -const ( - DatasetIndexStatusReady = "ready" -) - // Update 更新文件块 func (s *documentChunkService) Update(ctx context.Context, req *dto.UpdateDocumentChunkReq) (err error) { _, err = dao.DocumentChunk.Update(ctx, req) @@ -60,32 +52,29 @@ func (s *documentChunkService) DocsChunkMsg(ctx context.Context, msg any) (err e idx := eino.NewPGVectorIndexer(&eino.PGVectorIndexerOptions{ BatchSize: 10, }) + documentId := gconv.Int64(docs[0].MetaData[entity.DocumentChunkCol.DocumentId]) rows, err := idx.Store(ctx, docs, indexer.WithEmbedding(eino.EmbedderDashscope)) if err != nil || rows == 0 { g.Log().Error(ctx, "DocsChunkMsg rows: , err:", rows, err) + // 写入任务进度失败 任务类型为sql存储 + remark := " 向量存储数量: " + gconv.String(rows) + if err != nil { + remark = "向量存储失败: " + err.Error() + } + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: documentId, + TaskType: task.TaskTypeGenerateVector, + Status: task.TaskStatusFailed, + Remark: remark, + }) return } - tenantId := gconv.Uint64(docs[0].MetaData[entity.DocumentChunkCol.TenantId]) - creator := gconv.String(docs[0].MetaData[entity.DocumentChunkCol.Creator]) - documentId := gconv.Int64(docs[0].MetaData[entity.DocumentChunkCol.DocumentId]) - err = s.publishKnowledgeDocumentMsg(ctx, tenantId, creator, documentId, document.VectorStatusCompleted.Code()) - - return -} - -// publishKnowledgeDocumentMsg 发布消息 -func (s *documentChunkService) publishKnowledgeDocumentMsg(ctx context.Context, tenantId uint64, creator string, documentId int64, vectorStatus document.VectorStatus) (err error) { - knowledgeDocumentMsg := dto.KnowledgeDocumentMsg{ - TenantId: tenantId, - Creator: creator, - Id: documentId, - VectorStatus: vectorStatus, - } - err = gmq.GetGmq("primary").GmqPublish(ctx, &mq.RedisPubMessage{ - PubMessage: types.PubMessage{ - Topic: public.KnowledgeDocumentVectorStatusTopic, - Data: knowledgeDocumentMsg, - }, + // 写入任务进度成功 任务类型为sql存储 + err = Task.WriteTaskProgress(ctx, &dto.WriteTaskProgressReq{ + TaskId: documentId, + TaskType: task.TaskTypeGenerateVector, + Status: task.TaskStatusCompleted, + Remark: "向量生成完成", }) return } diff --git a/service/rag_query.go b/service/rag_query.go new file mode 100644 index 0000000..f802167 --- /dev/null +++ b/service/rag_query.go @@ -0,0 +1,52 @@ +package service + +import ( + "context" + "fmt" + "rag/common/eino" + "rag/model/dto" + + "github.com/cloudwego/eino/components/retriever" + "github.com/gogf/gf/v2/os/glog" +) + +var RAGQuery = new(ragQueryService) + +type ragQueryService struct{} + +// Query 执行RAG查询 +func (s *ragQueryService) Query(ctx context.Context, req *dto.RAGQueryReq) (*dto.RAGQueryRes, error) { + if req.TopK <= 0 { + req.TopK = 5 + } + + // 4. 使用向量检索器进行查询 + r, err := eino.NewPGVectorRetriever(&eino.PGVectorRetrieverConfig{ + Embedder: eino.EmbedderDashscope, + DefaultTopK: req.TopK, + }) + if err != nil { + glog.Errorf(ctx, "初始化向量检索器失败: %v", err) + return nil, fmt.Errorf("初始化向量检索器失败: %w", err) + } + + // 5. 执行向量检索 + docs, err := r.Retrieve(ctx, req.Content, retriever.WithEmbedding(eino.EmbedderDashscope), retriever.WithDSLInfo(map[string]any{ + "dataset_ids": req.DatasetIds, + })) + if err != nil { + glog.Errorf(ctx, "向量检索失败: %v", err) + return nil, fmt.Errorf("向量检索失败: %w", err) + } + + replyMsg, sources, err := eino.NewChatModel(ctx, req.Content, docs) + if err != nil { + glog.Errorf(ctx, "向量检索失败: %v", err) + return nil, fmt.Errorf("向量检索失败: %w", err) + } + + return &dto.RAGQueryRes{ + Answer: replyMsg.Content, + Sources: sources, + }, nil +} diff --git a/service/task.go b/service/task.go new file mode 100644 index 0000000..b039497 --- /dev/null +++ b/service/task.go @@ -0,0 +1,107 @@ +package service + +import ( + "context" + "rag/dao" + "rag/model/dto" + + "rag/common/task" + + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/util/gconv" +) + +var Task = new(taskService) + +type taskService struct{} + +// WriteTaskProgress 写入任务进度(核心方法) +func (s *taskService) WriteTaskProgress(ctx context.Context, req *dto.WriteTaskProgressReq) (err error) { + t, total, err := dao.Task.Get(ctx, &dto.GetTaskReq{ + TaskId: req.TaskId, + }) + if err != nil { + g.Log().Errorf(ctx, "查询任务失败: %v", err) + return err + } + taskVO := make([]dto.TaskVO, 0, total) + err = gconv.Struct(t, taskVO) + if err != nil { + g.Log().Errorf(ctx, "转换任务失败: %v", err) + return err + } + taskVO = append(taskVO, dto.TaskVO{ + TaskType: req.TaskType, + Status: req.Status, + }) + completed := IsAllSubTasksCompleted(taskVO) + + // 1. 查询是否已存在该文档的该类型任务 + existTask, _, err := dao.Task.Get(ctx, &dto.GetTaskReq{ + TaskId: req.TaskId, + TaskType: req.TaskType, + }) + if err != nil { + g.Log().Errorf(ctx, "查询任务失败: %v", err) + return err + } + + // 2. 如果不存在,则创建新任务 + if g.IsEmpty(existTask) { + createReq := &dto.CreateTaskReq{ + TaskId: req.TaskId, + TaskType: req.TaskType, + Status: req.Status, + Remark: req.Remark, + } + _, err = dao.Task.Insert(ctx, createReq) + } else { + // 3. 如果已存在,则更新任务 + updateReq := &dto.UpdateTaskReq{ + Id: existTask[0].Id, + Status: req.Status, + Remark: req.Remark, + } + _, err = dao.Task.Update(ctx, updateReq) + if err != nil { + g.Log().Errorf(ctx, "更新任务失败: %v", err) + return err + } + } + + if completed { + // 3. 如果已存在,则更新任务 + _, err = dao.Task.Update(ctx, &dto.UpdateTaskReq{ + TaskId: req.TaskId, + Status: task.TaskStatusCompleted, + }) + } + + return +} + +// IsAllSubTasksCompleted 判断三个子任务是否全部完成 +// 参数:传入当前文档的所有子任务列表 +func IsAllSubTasksCompleted(subTasks []dto.TaskVO) bool { + // 必须包含 3 种任务类型 + hasKeywords := false + hasVector := false + hasFullText := false + + for _, t := range subTasks { + // 子任务必须是【已完成】状态才计数 + if t.Status == task.TaskStatusCompleted { + switch t.TaskType { + case task.TaskTypeExtractKeywords: + hasKeywords = true + case task.TaskTypeGenerateVector: + hasVector = true + case task.TaskTypeFullTextSearch: + hasFullText = true + } + } + } + + // 三个任务全部完成 → 返回true + return hasKeywords && hasVector && hasFullText +} diff --git a/update.sql b/update.sql index aebb904..69decb3 100644 --- a/update.sql +++ b/update.sql @@ -114,6 +114,7 @@ COMMENT ON COLUMN rag_knowledge_document.file_path IS '文件存储路径(如M COMMENT ON COLUMN rag_knowledge_document.metadata IS '文件元数据,结构:{"author":"作者","tags":["标签1","标签2"],"custom":{"key":"值"}}'; --------------------pgsql创建rag_knowledge_document表语句--------------------------- + --------------------pgsql创建rag_knowledge_keyword表语句--------------------------- -- 关键词表(文档关键词+权重) CREATE TABLE IF NOT EXISTS rag_knowledge_keyword ( @@ -161,6 +162,49 @@ COMMENT ON COLUMN rag_knowledge_keyword.weight IS '权重'; --------------------pgsql创建rag_knowledge_keyword表语句--------------------------- +--------------------pgsql创建rag_knowledge_task表语句--------------------------- +-- 知识库任务表 +CREATE TABLE IF NOT EXISTS rag_knowledge_task ( + -- 基础字段(完全对齐项目规范) + id BIGINT PRIMARY KEY, -- 主键ID(非自增) + tenant_id BIGINT NOT NULL DEFAULT 0, -- 租户ID int8 + creator VARCHAR(64) NOT NULL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + updater VARCHAR(64) NOT NULL, + updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + deleted_at timestamp(6), + + -- 业务字段 + task_id BIGINT NOT NULL, -- 任务ID + task_type VARCHAR(32) NOT NULL, -- 任务类型 + status VARCHAR(32) NOT NULL, -- 任务状态 + executor VARCHAR(128) DEFAULT '', -- 执行器 + remark TEXT DEFAULT '' -- 备注 + ); + +-- 索引(高频查询) +CREATE INDEX idx_rkt_tenant_id ON rag_knowledge_task(tenant_id); +CREATE INDEX idx_rkt_task_id ON rag_knowledge_task(task_id); +CREATE INDEX idx_rkt_task_type ON rag_knowledge_task(task_type); +CREATE INDEX idx_rkt_status ON rag_knowledge_task(status); +CREATE INDEX idx_rkt_deleted_at ON rag_knowledge_task(deleted_at); + +-- 表和字段注释 +COMMENT ON TABLE rag_knowledge_task IS '知识库任务表'; +COMMENT ON COLUMN rag_knowledge_task.id IS '主键ID(非自增)'; +COMMENT ON COLUMN rag_knowledge_task.tenant_id IS '租户ID'; +COMMENT ON COLUMN rag_knowledge_task.creator IS '创建人'; +COMMENT ON COLUMN rag_knowledge_task.created_at IS '创建时间'; +COMMENT ON COLUMN rag_knowledge_task.updater IS '更新人'; +COMMENT ON COLUMN rag_knowledge_task.updated_at IS '更新时间'; +COMMENT ON COLUMN rag_knowledge_task.deleted_at IS '删除时间(软删)'; +COMMENT ON COLUMN rag_knowledge_task.task_id IS '任务ID'; +COMMENT ON COLUMN rag_knowledge_task.task_type IS '任务类型'; +COMMENT ON COLUMN rag_knowledge_task.status IS '任务状态'; +COMMENT ON COLUMN rag_knowledge_task.executor IS '执行器'; +COMMENT ON COLUMN rag_knowledge_task.remark IS '备注'; + +--------------------pgsql创建rag_knowledge_task表语句--------------------------- --------------------pgsql创建rag_vector_dataset_index表语句---------------------------