refactor: 重构文档处理流程和任务管理

This commit is contained in:
2026-04-09 09:11:43 +08:00
parent b6896f3fb4
commit 7f894745e9
34 changed files with 1216 additions and 1056 deletions

View File

@@ -1,166 +0,0 @@
package eino
import (
"context"
"errors"
"fmt"
"io"
"log"
"os"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/components/retriever"
"github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino-ext/components/model/ark"
)
func main() {
ctx := context.Background()
// ==========================================
// 1. 初始化三大组件
// ==========================================
// 1.1 向量检索(从知识库查客服知识)
ragRetriever := NewPGVectorRetriever()
// 1.2 提示词模板(客服角色 + 历史 + 知识库 + 用户问题)
chatTpl := newCustomerServiceTemplate()
// 1.3 大模型ARK
chatModel, err := ark.NewChatModel(ctx, &ark.ChatModelConfig{
APIKey: os.Getenv("ARK_API_KEY"),
Model: os.Getenv("ARK_MODEL_ID"),
})
if err != nil {
log.Fatal(err)
}
// ==========================================
// 2. 模拟会话:从 DB 读取历史对话
// ==========================================
sessionHistory := []*schema.Message{
{Role: schema.User, Content: "你们发什么快递?"},
{Role: schema.Assistant, Content: "默认发中通快递"},
{Role: schema.User, Content: "可以发顺丰吗?"},
}
// 当前用户问题
userQuery := "那顺丰需要加钱吗?"
// ==========================================
// 3. RAG 检索知识库
// ==========================================
docs, err := ragRetriever.Retrieve(ctx, userQuery)
if err != nil {
log.Fatal(err)
}
// 拼接参考知识
knowledge := ""
for i, doc := range docs {
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
}
// ==========================================
// 4. 模板格式化:系统提示 + 历史 + 知识 + 当前问题
// ==========================================
msgs, err := chatTpl.Format(ctx, map[string]any{
"history": sessionHistory,
"knowledge": knowledge,
"question": userQuery,
})
if err != nil {
log.Fatal(err)
}
// ==========================================
// 5. 流式调用大模型生成客服回答
// ==========================================
fmt.Println("\n=== 客服回复 ===")
stream, err := chatModel.Stream(ctx, msgs)
if err != nil {
log.Fatal(err)
}
fullReply := make([]*schema.Message, 0, 100)
for {
chunk, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
log.Fatal(err)
}
fmt.Print(chunk.Content)
fullReply = append(fullReply, chunk)
}
// ==========================================
// 6. 拼接完整回复,存入 DB 作为新历史
// ==========================================
replyMsg, _ := schema.ConcatMessages(fullReply)
sessionHistory = append(sessionHistory,
&schema.Message{Role: schema.User, Content: userQuery},
replyMsg,
)
// 接下来把 sessionHistory 存回你的 MySQL/Redis 即可
}
// ==========================================
// 本地客服提示词模板(不需要 MCP
// ==========================================
func newCustomerServiceTemplate() prompt.ChatTemplate {
// 系统提示 + 多轮对话 + 知识库 + 用户问题
return prompt.FromMessages(schema.Messages{
{
Role: schema.System,
Content: `你是电商智能客服,语气友好简洁。
请严格根据参考知识回答,不知道就说“抱歉,这个问题我需要帮你转接人工”。
参考知识:
{{.knowledge}}`,
},
// 历史对话会自动渲染在这里
{{range .history}}{{.}},{{end}},
// 当前用户问题
{Role: schema.User, Content: "{{.question}}"},
})
}
// ==========================================
// PGVector 检索器(简化可直接用)
// ==========================================
type PGVectorRetriever struct {
topK int
}
func NewPGVectorRetriever() retriever.Retriever {
return &PGVectorRetriever{topK: 3}
}
func (r *PGVectorRetriever) Retrieve(
ctx context.Context,
query string,
opts ...retriever.Option,
) ([]*schema.Document, error) {
options := retriever.GetCommonOptions(nil, opts...)
topK := r.topK
if options.TopK != nil {
topK = *options.TopK
}
// ===== 这里替换成你真实的 PG 向量检索 SQL =====
// 模拟知识库
return []*schema.Document{
{
ID: "1",
Content: "顺丰快递需要补10元运费差价",
},
{
ID: "2",
Content: "订单满99元可免费升级顺丰",
},
}, nil
}

View File

@@ -1,107 +0,0 @@
package eino
import (
"context"
"fmt"
"github.com/cloudwego/eino/schema"
"github.com/elastic/go-elasticsearch/v8"
"github.com/cloudwego/eino-ext/components/indexer/es8"
)
const (
indexName = "eino_example"
fieldContent = "content"
fieldContentVector = "content_vector"
fieldExtraLocation = "location"
docExtraLocation = "location"
)
func TestIndexer() {
ctx := context.Background()
// 1. 创建 ES 客户端
client, err := elasticsearch.NewClient(elasticsearch.Config{
Addresses: []string{"http://localhost:9200"},
})
if err != nil {
fmt.Printf("create client error: %v\n", err)
return
}
// 2. 定义 Index Spec选填如果索引不存在将自动创建
indexSpec := &es8.IndexSpec{
Settings: map[string]any{
"number_of_shards": 1,
"number_of_replicas": 0,
},
Mappings: map[string]any{
"properties": map[string]any{
fieldContentVector: map[string]any{
"type": "dense_vector",
"dims": 1024,
"index": true,
"similarity": "l2_norm",
},
},
},
}
// 4. 准备文档
// 文档通常包含 ID 和 Content
// 也可以包含额外的 Metadata 用于过滤或其他用途
docs := []*schema.Document{
{
ID: "1",
Content: "Eiffel Tower: Located in Paris, France.",
MetaData: map[string]any{
docExtraLocation: "France",
},
},
{
ID: "2",
Content: "The Great Wall: Located in China.",
MetaData: map[string]any{
docExtraLocation: "China",
},
},
}
// 5. 创建 ES 索引器组件
indexer, err := es8.NewIndexer(ctx, &es8.IndexerConfig{
Client: client,
Index: indexName,
IndexSpec: indexSpec, // 添加此项以启用自动索引创建
BatchSize: 10,
// DocumentToFields 指定如何将文档字段映射到 ES 字段
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]es8.FieldValue, err error) {
return map[string]es8.FieldValue{
fieldContent: {
Value: doc.Content,
EmbedKey: fieldContentVector, // 对文档内容进行向量化并保存到 "content_vector" 字段
},
fieldExtraLocation: {
// 额外的 metadata 字段
Value: doc.MetaData[docExtraLocation],
},
}, nil
},
// 提供 embedding 组件用于向量化
Embedding: EmbedderDashscope,
})
if err != nil {
fmt.Printf("create indexer error: %v\n", err)
return
}
// 6. 索引文档
ids, err := indexer.Store(ctx, docs)
if err != nil {
fmt.Printf("index error: %v\n", err)
return
}
fmt.Println("indexed ids:", ids)
}

View File

@@ -1,49 +0,0 @@
package eino
import (
"time"
"gitea.com/red-future/common/beans"
)
// BaseTask 任务基类 - MongoDB版本
type BaseTask struct {
beans.MongoBaseDO `bson:",inline"`
// 任务信息
TaskType TaskType `bson:"taskType" json:"taskType"`
Status TaskStatus `bson:"status" json:"status"`
Priority TaskPriority `bson:"priority,omitempty" json:"priority,omitempty"`
// 进度
TotalItems int64 `bson:"totalItems" json:"totalItems"`
ProcessedItems int64 `bson:"processedItems" json:"processedItems"`
Progress float64 `bson:"progress" json:"progress"`
// 结果
StartTime *time.Time `bson:"startTime" json:"startTime"`
EndTime *time.Time `bson:"endTime,omitempty" json:"endTime,omitempty"`
Duration int64 `bson:"duration,omitempty" json:"duration,omitempty"`
SuccessCount int64 `bson:"successCount" json:"successCount"`
FailCount int64 `bson:"failCount" json:"failCount"`
// 其他
Executor string `bson:"executor,omitempty" json:"executor,omitempty"`
}
// SQLBaseTask 任务基类 - SQL版本
type SQLBaseTask struct {
beans.SQLBaseDO
// 任务信息
TaskType TaskType `json:"taskType"`
Status TaskStatus `json:"status"`
Priority TaskPriority `json:"priority,omitempty"`
// 进度
TotalItems int64 `json:"totalItems"`
ProcessedItems int64 `json:"processedItems"`
Progress float64 `json:"progress"`
// 结果
StartTime *time.Time `json:"startTime"`
EndTime *time.Time `json:"endTime,omitempty"`
Duration int64 `json:"duration,omitempty"`
SuccessCount int64 `json:"successCount"`
FailCount int64 `json:"failCount"`
// 其他
Executor string `json:"executor,omitempty"`
}

View File

@@ -1,94 +0,0 @@
package eino
import (
"context"
"encoding/json"
"fmt"
"github.com/cloudwego/eino/schema"
"github.com/elastic/go-elasticsearch/v8"
"github.com/elastic/go-elasticsearch/v8/typedapi/types"
"github.com/cloudwego/eino-ext/components/retriever/es8"
"github.com/cloudwego/eino-ext/components/retriever/es8/search_mode"
)
func TestRetriever() {
ctx := context.Background()
client, _ := elasticsearch.NewClient(elasticsearch.Config{
Addresses: []string{"http://localhost:9200"},
})
// 创建 retriever 组件
retriever, _ := es8.NewRetriever(ctx, &es8.RetrieverConfig{
Client: client,
Index: indexName,
TopK: 5,
SearchMode: search_mode.SearchModeApproximate(&search_mode.ApproximateConfig{
QueryFieldName: fieldContent,
VectorFieldName: fieldContentVector,
Hybrid: false,
// RRF 仅在特定许可证下可用
// 参见: https://www.elastic.co/subscriptions
RRF: false,
RRFRankConstant: nil,
RRFWindowSize: nil,
}),
ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) {
doc = &schema.Document{
ID: *hit.Id_,
Content: "",
MetaData: map[string]any{},
}
var src map[string]any
if err = json.Unmarshal(hit.Source_, &src); err != nil {
return nil, err
}
for field, val := range src {
switch field {
case fieldContent:
doc.Content = val.(string)
case fieldContentVector:
var v []float64
for _, item := range val.([]interface{}) {
v = append(v, item.(float64))
}
doc.WithDenseVector(v)
case fieldExtraLocation:
doc.MetaData[docExtraLocation] = val.(string)
}
}
if hit.Score_ != nil {
doc.WithScore(float64(*hit.Score_))
}
return doc, nil
},
Embedding: EmbedderDashscope,
})
// 不带过滤器的搜索
docs, _ := retriever.Retrieve(ctx, "tourist attraction")
// 带过滤器的搜索
docs, _ = retriever.Retrieve(ctx, "tourist attraction",
es8.WithFilters([]types.Query{{
Term: map[string]types.TermQuery{
fieldExtraLocation: {
CaseInsensitive: of(true),
Value: "China",
},
},
}}),
)
fmt.Printf("retrieved docs: %+v\n", docs)
}
func of[T any](v T) *T {
return &v
}

125
common/eino/chat_model.go Normal file
View File

@@ -0,0 +1,125 @@
package eino
import (
"context"
"errors"
"fmt"
"io"
"github.com/cloudwego/eino-ext/components/model/qwen"
"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/glog"
"github.com/gogf/gf/v2/util/gconv"
)
var globalChatModel *qwen.ChatModel
func init() {
ctx := context.Background()
apiKey := g.Cfg().MustGet(ctx, "eino.chatmodel.apiKey").String()
model := g.Cfg().MustGet(ctx, "eino.chatmodel.model").String()
var err error
globalChatModel, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
APIKey: apiKey,
Model: model,
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
Temperature: gconv.PtrFloat32(0.7), // 客服最佳
MaxTokens: gconv.PtrInt(1024), // 最长回答
TopP: gconv.PtrFloat32(1.0),
})
if err != nil {
glog.Errorf(ctx, "初始化大模型失败: %v", err)
}
return
}
// NewChatModel 只处理逻辑,不复用创建模型
func NewChatModel(ctx context.Context, content string, docs []*schema.Document) (replyMsg *schema.Message, sources []string, err error) {
// 1. 构建参考知识
knowledge, sources := buildKnowledgeAndSources(docs)
// 2. 构建提示词
msgs, err := buildPromptMessages(ctx, knowledge, content)
if err != nil {
return
}
// 3. 🔥 直接使用全局单例,不重复创建
replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs)
return
}
// buildKnowledgeAndSources 拼接参考知识 + 提取文档来源
func buildKnowledgeAndSources(docs []*schema.Document) (string, []string) {
var knowledge string
var sources []string
for i, doc := range docs {
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
// 提取 document_id
if docID, ok := doc.MetaData["document_id"].(int64); ok && docID > 0 {
sources = append(sources, gconv.String(docID))
}
}
return knowledge, sources
}
// buildPromptMessages 构建提示词模板
func buildPromptMessages(ctx context.Context, knowledge string, question string) (msgs []*schema.Message, err error) {
promptTpl := prompt.FromMessages(
schema.FString,
&schema.Message{
Role: schema.System,
// Content: `你是专业的客服助手,语气友好。
//如果参考知识中有相关信息,请优先依据参考知识回答。
//如果没有相关信息,就正常回答,不要说无法回答。
//
//参考知识:
//{knowledge}`,
Content: `你是专业的客服助手,语气友好。
请根据参考知识回答用户问题,无法回答则说:抱歉,我暂时无法回答这个问题。
参考知识:
{knowledge}`,
},
&schema.Message{
Role: schema.User,
Content: "{question}",
},
)
return promptTpl.Format(ctx, map[string]any{
"knowledge": knowledge,
"question": question,
})
}
// streamGenerateAnswer 流式生成
func streamGenerateAnswer(ctx context.Context, chatModel *qwen.ChatModel, msgs []*schema.Message) (reply *schema.Message, err error) {
sr, err := chatModel.Stream(ctx, msgs)
if err != nil {
return nil, fmt.Errorf("stream failed: %w", err)
}
var chunks []*schema.Message
for {
chunk, err := sr.Recv()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, fmt.Errorf("stream recv failed: %w", err)
}
chunks = append(chunks, chunk)
}
return schema.ConcatMessages(chunks)
}

View File

@@ -1,273 +0,0 @@
/*
* Copyright 2024 Red Future Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package eino
import (
"context"
"fmt"
"net/http"
"time"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/embedding"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/net/gclient"
"github.com/gogf/gf/v2/util/gconv"
)
var (
// 千问API默认配置
defaultBaseURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
defaultTimeout = 10 * time.Minute
defaultRetryTimes = 2
)
type QwenEmbeddingConfig struct {
// Timeout specifies the maximum duration to wait for API responses
// Optional. Default: 10 minutes
Timeout *time.Duration `json:"timeout"`
// HTTPClient specifies the client to send HTTP requests.
// Optional. Default &http.Client{Timeout: Timeout}
HTTPClient *http.Client `json:"http_client"`
// RetryTimes specifies the number of retry attempts for failed API calls
// Optional. Default: 2
RetryTimes *int `json:"retry_times"`
// BaseURL specifies the base URL for Qwen DashScope service
// Optional. Default: "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
BaseURL string `json:"base_url"`
// APIKey specifies the API Key for authentication
// Required
APIKey string `json:"api_key"`
// Model specifies the model name for Qwen embedding
// Required. Examples: "text-embedding-v2", "text-embedding-v3"
Model string `json:"model"`
// TextType specifies the type of text: "document" or "query"
// Optional. Default: "document"
TextType string `json:"text_type"`
// MaxConcurrentRequests specifies the maximum number of concurrent requests allowed
// Optional. Default: 5
MaxConcurrentRequests *int `json:"max_concurrent_requests"`
}
type QwenEmbedder struct {
client *gclient.Client
conf *QwenEmbeddingConfig
}
// EmbeddingRequest 千问embedding请求结构
type EmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
// EmbeddingResponse 千问embedding响应结构
type EmbeddingResponse struct {
Output struct {
Embeddings []struct {
TextIndex int `json:"text_index"`
Embedding []float64 `json:"embedding"`
} `json:"embeddings"`
} `json:"output"`
Usage struct {
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
RequestID string `json:"request_id"`
}
type APIError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestID string `json:"request_id"`
}
func (e *APIError) Error() string {
return fmt.Sprintf("API Error: %s - %s (RequestID: %s)", e.Code, e.Message, e.RequestID)
}
func buildQwenClient(config *QwenEmbeddingConfig) *gclient.Client {
if len(config.BaseURL) == 0 {
config.BaseURL = defaultBaseURL
}
if config.Timeout == nil {
config.Timeout = &defaultTimeout
}
if config.RetryTimes == nil {
defaultRetryTimes := 2
config.RetryTimes = &defaultRetryTimes
}
if len(config.TextType) == 0 {
config.TextType = "document"
}
if config.MaxConcurrentRequests == nil {
defaultMaxConcurrentRequests := 5
config.MaxConcurrentRequests = &defaultMaxConcurrentRequests
}
client := g.Client()
client.SetTimeout(*config.Timeout)
return client
}
func NewQwenEmbedder(ctx context.Context, config *QwenEmbeddingConfig) (*QwenEmbedder, error) {
if len(config.APIKey) == 0 {
return nil, fmt.Errorf("[Qwen] APIKey is required")
}
if len(config.Model) == 0 {
return nil, fmt.Errorf("[Qwen] Model is required")
}
client := buildQwenClient(config)
return &QwenEmbedder{
client: client,
conf: config,
}, nil
}
func (e *QwenEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) (
[][]float64, error) {
if len(texts) == 0 {
return nil, fmt.Errorf("[Qwen] texts cannot be empty")
}
options := embedding.GetCommonOptions(&embedding.Options{
Model: &e.conf.Model,
}, opts...)
conf := &embedding.Config{
Model: dereferenceOrZero(options.Model),
}
ctx = callbacks.EnsureRunInfo(ctx, e.GetType(), components.ComponentOfEmbedding)
ctx = callbacks.OnStart(ctx, &embedding.CallbackInput{
Texts: texts,
Config: conf,
})
defer func() {
if err := recover(); err != nil {
callbacks.OnError(ctx, fmt.Errorf("[Qwen] panic: %v", err))
}
}()
var usage *embedding.TokenUsage
var embeddings [][]float64
var err error
// 调用千问API获取embedding
embeddings, usage, err = e.callEmbeddingAPI(ctx, texts)
if err != nil {
callbacks.OnError(ctx, err)
return nil, err
}
callbacks.OnEnd(ctx, &embedding.CallbackOutput{
Embeddings: embeddings,
Config: conf,
TokenUsage: usage,
})
return embeddings, nil
}
func (e *QwenEmbedder) callEmbeddingAPI(ctx context.Context, texts []string) ([][]float64, *embedding.TokenUsage, error) {
// 构建请求
var req EmbeddingRequest
req.Model = e.conf.Model
req.Input.Texts = texts
req.Parameters.TextType = e.conf.TextType
// 调用API
client := e.client.Clone()
client.SetHeader("Authorization", "Bearer "+e.conf.APIKey)
client.SetHeader("Content-Type", "application/json")
client.SetTimeout(*e.conf.Timeout)
resp, err := client.Post(ctx, e.conf.BaseURL, req)
if err != nil {
return nil, nil, fmt.Errorf("[Qwen] HTTP request error: %w", err)
}
defer resp.Close()
// 检查状态码
if resp.StatusCode != http.StatusOK {
var errResp APIError
result := resp.ReadAll()
if err = gconv.Struct(result, &errResp); err == nil && errResp.Code != "" {
return nil, nil, &errResp
}
return nil, nil, fmt.Errorf("[Qwen] HTTP status error: %d", resp.StatusCode)
}
// 解析响应
var apiResp EmbeddingResponse
result := resp.ReadAll()
if err = gconv.Struct(result, &apiResp); err != nil {
return nil, nil, fmt.Errorf("[Qwen] parse response error: %w", err)
}
// 解析响应结果
embeddings := make([][]float64, len(texts))
for _, emb := range apiResp.Output.Embeddings {
if emb.TextIndex >= 0 && emb.TextIndex < len(embeddings) {
embeddings[emb.TextIndex] = emb.Embedding
}
}
usage := &embedding.TokenUsage{
TotalTokens: apiResp.Usage.TotalTokens,
}
g.Log().Debugf(ctx, "[Qwen] Embedding success: request_id=%s, total_tokens=%d", apiResp.RequestID, usage.TotalTokens)
return embeddings, usage, nil
}
func (e *QwenEmbedder) GetType() string {
return getType()
}
func (e *QwenEmbedder) IsCallbacksEnabled() bool {
return true
}
func getType() string {
return "Qwen"
}
func dereferenceOrZero[T any](v *T) T {
if v == nil {
var t T
return t
}
return *v
}

View File

@@ -1,11 +0,0 @@
package eino
// TaskPriority 任务优先级
type TaskPriority string
const (
TaskPriorityLow TaskPriority = "low" // 低优先级
TaskPriorityMedium TaskPriority = "medium" // 中优先级
TaskPriorityHigh TaskPriority = "high" // 高优先级
TaskPriorityUrgent TaskPriority = "urgent" // 紧急
)

View File

@@ -3,6 +3,8 @@ package eino
import (
"context"
"errors"
"rag/dao"
"sort"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components/embedding"
@@ -16,12 +18,14 @@ type PGVectorRetrieverConfig struct {
Embedder embedding.Embedder
DefaultTopK int
DefaultIndex string
DSLInfo map[string]any
}
type PGVectorRetriever struct {
embedder embedding.Embedder
topK int
index string
dslInfo map[string]any
}
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
@@ -36,43 +40,62 @@ func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever,
embedder: config.Embedder,
topK: config.DefaultTopK,
index: config.DefaultIndex,
dslInfo: config.DSLInfo,
}, nil
}
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
// 1. 处理公共 Option官方标准写法
options := &retriever.Options{
Index: &r.index,
TopK: &r.topK,
DSLInfo: r.dslInfo,
Embedding: r.embedder,
}
options = retriever.GetCommonOptions(options, opts...)
// 2. 回调(官方标准)
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
Query: query,
TopK: *options.TopK,
})
// 3. 执行检索
docs, err := r.doRetrieve(ctx, query, options)
// ==========================================
// 🔥 双路检索:向量 + 全文
// ==========================================
docsVector, err := r.doRetrieveVector(ctx, query, options)
if err != nil {
callbacks.OnError(ctx, err)
return nil, err
}
// 4. 完成回调
callbacks.OnEnd(ctx, &retriever.CallbackOutput{
Docs: docs,
docsFulltext, err := r.doRetrieveMeilisearch(ctx, query, options)
if err != nil {
callbacks.OnError(ctx, err)
return nil, err
}
// 合并 + 去重
docs := mergeAndDeduplicate(docsVector, docsFulltext)
// 排序distance 越小越靠前)
sort.Slice(docs, func(i, j int) bool {
d1 := gconv.Float64(docs[i].MetaData["distance"])
d2 := gconv.Float64(docs[j].MetaData["distance"])
return d1 < d2
})
// 最多保留 topK
if len(docs) > *options.TopK {
docs = docs[:*options.TopK]
}
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs})
return docs, nil
}
func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
// 1. 生成向量
// ==========================================
// 1. 向量检索PG
// ==========================================
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
if err != nil {
return nil, err
@@ -81,37 +104,76 @@ func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *
return nil, errors.New("empty query vector")
}
queryVec := pgvector.NewVector(vectors[0])
queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
topK := *opts.TopK
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
// 2. PG 向量相似度检索 SQL
sql := `
SELECT id, content, dataset_id, document_id,
vector <-> ? AS distance
FROM document_chunk
ORDER BY distance ASC
LIMIT ?
`
// 3. 查询
rows, err := dao.DocumentChunk.GetDB().GetAll(ctx, sql, queryVec, topK)
rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
if err != nil {
return nil, err
}
// 4. 转为 Eino Document
docs := make([]*schema.Document, 0, len(rows))
for _, row := range rows {
docs = append(docs, &schema.Document{
ID: gconv.String(row["id"]),
Content: gconv.String(row["content"]),
Metadata: map[string]any{
"dataset_id": row["dataset_id"],
"document_id": row["document_id"],
"distance": row["distance"],
MetaData: map[string]any{
"dataset_id": gconv.Int64(row["dataset_id"]),
"document_id": gconv.Int64(row["document_id"]),
"distance": gconv.Float64(row["distance"]),
"retrieve_by": "vector",
},
})
}
return docs, nil
}
// ==========================================
// 2. 全文检索Meilisearch🔥 新增
// ==========================================
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
topK := *opts.TopK
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
// 调用你已有的 Meilisearch DAO
rows, err := dao.DocumentChunk.SearchByKeywords(ctx, query, datasetIds, topK)
if err != nil {
return nil, err
}
docs := make([]*schema.Document, 0, len(rows))
for _, row := range rows {
docs = append(docs, &schema.Document{
ID: gconv.String(row["id"]),
Content: gconv.String(row["content"]),
MetaData: map[string]any{
"dataset_id": gconv.Int64(row["dataset_id"]),
"document_id": gconv.Int64(row["document_id"]),
"distance": 0.1, // 全文结果给高分
"retrieve_by": "fulltext",
},
})
}
return docs, nil
}
// ==========================================
// 合并去重
// ==========================================
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
idMap := make(map[string]*schema.Document)
for _, d := range vecDocs {
idMap[d.ID] = d
}
for _, d := range fullDocs {
if _, exists := idMap[d.ID]; !exists {
idMap[d.ID] = d
}
}
merged := make([]*schema.Document, 0, len(idMap))
for _, d := range idMap {
merged = append(merged, d)
}
return merged
}

View File

@@ -1,12 +0,0 @@
package eino
// TaskStatus 任务状态
type TaskStatus string
const (
TaskStatusPending TaskStatus = "pending" // 待处理
TaskStatusRunning TaskStatus = "running" // 运行中
TaskStatusCompleted TaskStatus = "completed" // 已完成
TaskStatusFailed TaskStatus = "failed" // 失败
TaskStatusCancelled TaskStatus = "cancelled" // 已取消
)

View File

@@ -1,14 +0,0 @@
package eino
// TaskType 任务类型
type TaskType string
const (
TaskTypeDocumentIngestion TaskType = "document_ingestion" // 文档摄入任务
TaskTypeVectorIngestion TaskType = "vector_ingestion" // 向量摄入任务
TaskTypeIndexCreation TaskType = "index_creation" // 索引创建任务
TaskTypeQAProcessing TaskType = "qa_processing" // 问答处理任务
TaskTypeKnowledgeConstruction TaskType = "knowledge_construction" // 知识库构建任务
TaskTypeGraphBuilding TaskType = "graph_building" // 图谱构建任务
TaskTypeKnowledgeSync TaskType = "knowledge_sync" // 知识同步任务
)