refactor: 重构文档处理流程和任务管理
This commit is contained in:
166
common/eino/a.go
166
common/eino/a.go
@@ -1,166 +0,0 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
// ==========================================
|
||||
// 1. 初始化三大组件
|
||||
// ==========================================
|
||||
// 1.1 向量检索(从知识库查客服知识)
|
||||
ragRetriever := NewPGVectorRetriever()
|
||||
|
||||
// 1.2 提示词模板(客服角色 + 历史 + 知识库 + 用户问题)
|
||||
chatTpl := newCustomerServiceTemplate()
|
||||
|
||||
// 1.3 大模型(ARK)
|
||||
chatModel, err := ark.NewChatModel(ctx, &ark.ChatModelConfig{
|
||||
APIKey: os.Getenv("ARK_API_KEY"),
|
||||
Model: os.Getenv("ARK_MODEL_ID"),
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 2. 模拟会话:从 DB 读取历史对话
|
||||
// ==========================================
|
||||
sessionHistory := []*schema.Message{
|
||||
{Role: schema.User, Content: "你们发什么快递?"},
|
||||
{Role: schema.Assistant, Content: "默认发中通快递"},
|
||||
{Role: schema.User, Content: "可以发顺丰吗?"},
|
||||
}
|
||||
|
||||
// 当前用户问题
|
||||
userQuery := "那顺丰需要加钱吗?"
|
||||
|
||||
// ==========================================
|
||||
// 3. RAG 检索知识库
|
||||
// ==========================================
|
||||
docs, err := ragRetriever.Retrieve(ctx, userQuery)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// 拼接参考知识
|
||||
knowledge := ""
|
||||
for i, doc := range docs {
|
||||
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 4. 模板格式化:系统提示 + 历史 + 知识 + 当前问题
|
||||
// ==========================================
|
||||
msgs, err := chatTpl.Format(ctx, map[string]any{
|
||||
"history": sessionHistory,
|
||||
"knowledge": knowledge,
|
||||
"question": userQuery,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 5. 流式调用大模型生成客服回答
|
||||
// ==========================================
|
||||
fmt.Println("\n=== 客服回复 ===")
|
||||
stream, err := chatModel.Stream(ctx, msgs)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fullReply := make([]*schema.Message, 0, 100)
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Print(chunk.Content)
|
||||
fullReply = append(fullReply, chunk)
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 6. 拼接完整回复,存入 DB 作为新历史
|
||||
// ==========================================
|
||||
replyMsg, _ := schema.ConcatMessages(fullReply)
|
||||
sessionHistory = append(sessionHistory,
|
||||
&schema.Message{Role: schema.User, Content: userQuery},
|
||||
replyMsg,
|
||||
)
|
||||
|
||||
// 接下来把 sessionHistory 存回你的 MySQL/Redis 即可
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 本地客服提示词模板(不需要 MCP)
|
||||
// ==========================================
|
||||
func newCustomerServiceTemplate() prompt.ChatTemplate {
|
||||
// 系统提示 + 多轮对话 + 知识库 + 用户问题
|
||||
return prompt.FromMessages(schema.Messages{
|
||||
{
|
||||
Role: schema.System,
|
||||
Content: `你是电商智能客服,语气友好简洁。
|
||||
请严格根据参考知识回答,不知道就说“抱歉,这个问题我需要帮你转接人工”。
|
||||
|
||||
参考知识:
|
||||
{{.knowledge}}`,
|
||||
},
|
||||
// 历史对话会自动渲染在这里
|
||||
{{range .history}}{{.}},{{end}},
|
||||
// 当前用户问题
|
||||
{Role: schema.User, Content: "{{.question}}"},
|
||||
})
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// PGVector 检索器(简化可直接用)
|
||||
// ==========================================
|
||||
type PGVectorRetriever struct {
|
||||
topK int
|
||||
}
|
||||
|
||||
func NewPGVectorRetriever() retriever.Retriever {
|
||||
return &PGVectorRetriever{topK: 3}
|
||||
}
|
||||
|
||||
func (r *PGVectorRetriever) Retrieve(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
opts ...retriever.Option,
|
||||
) ([]*schema.Document, error) {
|
||||
|
||||
options := retriever.GetCommonOptions(nil, opts...)
|
||||
topK := r.topK
|
||||
if options.TopK != nil {
|
||||
topK = *options.TopK
|
||||
}
|
||||
|
||||
// ===== 这里替换成你真实的 PG 向量检索 SQL =====
|
||||
// 模拟知识库
|
||||
return []*schema.Document{
|
||||
{
|
||||
ID: "1",
|
||||
Content: "顺丰快递需要补10元运费差价",
|
||||
},
|
||||
{
|
||||
ID: "2",
|
||||
Content: "订单满99元可免费升级顺丰",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
107
common/eino/b.go
107
common/eino/b.go
@@ -1,107 +0,0 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/elastic/go-elasticsearch/v8"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/indexer/es8"
|
||||
)
|
||||
|
||||
const (
|
||||
indexName = "eino_example"
|
||||
fieldContent = "content"
|
||||
fieldContentVector = "content_vector"
|
||||
fieldExtraLocation = "location"
|
||||
docExtraLocation = "location"
|
||||
)
|
||||
|
||||
func TestIndexer() {
|
||||
ctx := context.Background()
|
||||
|
||||
// 1. 创建 ES 客户端
|
||||
client, err := elasticsearch.NewClient(elasticsearch.Config{
|
||||
Addresses: []string{"http://localhost:9200"},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("create client error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 定义 Index Spec(选填:如果索引不存在,将自动创建)
|
||||
indexSpec := &es8.IndexSpec{
|
||||
Settings: map[string]any{
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0,
|
||||
},
|
||||
Mappings: map[string]any{
|
||||
"properties": map[string]any{
|
||||
fieldContentVector: map[string]any{
|
||||
"type": "dense_vector",
|
||||
"dims": 1024,
|
||||
"index": true,
|
||||
"similarity": "l2_norm",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 4. 准备文档
|
||||
// 文档通常包含 ID 和 Content
|
||||
// 也可以包含额外的 Metadata 用于过滤或其他用途
|
||||
docs := []*schema.Document{
|
||||
{
|
||||
ID: "1",
|
||||
Content: "Eiffel Tower: Located in Paris, France.",
|
||||
MetaData: map[string]any{
|
||||
docExtraLocation: "France",
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "2",
|
||||
Content: "The Great Wall: Located in China.",
|
||||
MetaData: map[string]any{
|
||||
docExtraLocation: "China",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 5. 创建 ES 索引器组件
|
||||
indexer, err := es8.NewIndexer(ctx, &es8.IndexerConfig{
|
||||
Client: client,
|
||||
Index: indexName,
|
||||
IndexSpec: indexSpec, // 添加此项以启用自动索引创建
|
||||
BatchSize: 10,
|
||||
// DocumentToFields 指定如何将文档字段映射到 ES 字段
|
||||
DocumentToFields: func(ctx context.Context, doc *schema.Document) (field2Value map[string]es8.FieldValue, err error) {
|
||||
return map[string]es8.FieldValue{
|
||||
fieldContent: {
|
||||
Value: doc.Content,
|
||||
EmbedKey: fieldContentVector, // 对文档内容进行向量化并保存到 "content_vector" 字段
|
||||
},
|
||||
fieldExtraLocation: {
|
||||
// 额外的 metadata 字段
|
||||
Value: doc.MetaData[docExtraLocation],
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
// 提供 embedding 组件用于向量化
|
||||
Embedding: EmbedderDashscope,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("create indexer error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 6. 索引文档
|
||||
ids, err := indexer.Store(ctx, docs)
|
||||
if err != nil {
|
||||
fmt.Printf("index error: %v\n", err)
|
||||
return
|
||||
}
|
||||
fmt.Println("indexed ids:", ids)
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
)
|
||||
|
||||
// BaseTask 任务基类 - MongoDB版本
|
||||
type BaseTask struct {
|
||||
beans.MongoBaseDO `bson:",inline"`
|
||||
// 任务信息
|
||||
TaskType TaskType `bson:"taskType" json:"taskType"`
|
||||
Status TaskStatus `bson:"status" json:"status"`
|
||||
Priority TaskPriority `bson:"priority,omitempty" json:"priority,omitempty"`
|
||||
// 进度
|
||||
TotalItems int64 `bson:"totalItems" json:"totalItems"`
|
||||
ProcessedItems int64 `bson:"processedItems" json:"processedItems"`
|
||||
Progress float64 `bson:"progress" json:"progress"`
|
||||
// 结果
|
||||
StartTime *time.Time `bson:"startTime" json:"startTime"`
|
||||
EndTime *time.Time `bson:"endTime,omitempty" json:"endTime,omitempty"`
|
||||
Duration int64 `bson:"duration,omitempty" json:"duration,omitempty"`
|
||||
SuccessCount int64 `bson:"successCount" json:"successCount"`
|
||||
FailCount int64 `bson:"failCount" json:"failCount"`
|
||||
// 其他
|
||||
Executor string `bson:"executor,omitempty" json:"executor,omitempty"`
|
||||
}
|
||||
|
||||
// SQLBaseTask 任务基类 - SQL版本
|
||||
type SQLBaseTask struct {
|
||||
beans.SQLBaseDO
|
||||
// 任务信息
|
||||
TaskType TaskType `json:"taskType"`
|
||||
Status TaskStatus `json:"status"`
|
||||
Priority TaskPriority `json:"priority,omitempty"`
|
||||
// 进度
|
||||
TotalItems int64 `json:"totalItems"`
|
||||
ProcessedItems int64 `json:"processedItems"`
|
||||
Progress float64 `json:"progress"`
|
||||
// 结果
|
||||
StartTime *time.Time `json:"startTime"`
|
||||
EndTime *time.Time `json:"endTime,omitempty"`
|
||||
Duration int64 `json:"duration,omitempty"`
|
||||
SuccessCount int64 `json:"successCount"`
|
||||
FailCount int64 `json:"failCount"`
|
||||
// 其他
|
||||
Executor string `json:"executor,omitempty"`
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/elastic/go-elasticsearch/v8"
|
||||
"github.com/elastic/go-elasticsearch/v8/typedapi/types"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/retriever/es8"
|
||||
"github.com/cloudwego/eino-ext/components/retriever/es8/search_mode"
|
||||
)
|
||||
|
||||
func TestRetriever() {
|
||||
ctx := context.Background()
|
||||
|
||||
client, _ := elasticsearch.NewClient(elasticsearch.Config{
|
||||
Addresses: []string{"http://localhost:9200"},
|
||||
})
|
||||
|
||||
// 创建 retriever 组件
|
||||
retriever, _ := es8.NewRetriever(ctx, &es8.RetrieverConfig{
|
||||
Client: client,
|
||||
Index: indexName,
|
||||
TopK: 5,
|
||||
SearchMode: search_mode.SearchModeApproximate(&search_mode.ApproximateConfig{
|
||||
QueryFieldName: fieldContent,
|
||||
VectorFieldName: fieldContentVector,
|
||||
Hybrid: false,
|
||||
// RRF 仅在特定许可证下可用
|
||||
// 参见: https://www.elastic.co/subscriptions
|
||||
RRF: false,
|
||||
RRFRankConstant: nil,
|
||||
RRFWindowSize: nil,
|
||||
}),
|
||||
ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) {
|
||||
doc = &schema.Document{
|
||||
ID: *hit.Id_,
|
||||
Content: "",
|
||||
MetaData: map[string]any{},
|
||||
}
|
||||
|
||||
var src map[string]any
|
||||
if err = json.Unmarshal(hit.Source_, &src); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for field, val := range src {
|
||||
switch field {
|
||||
case fieldContent:
|
||||
doc.Content = val.(string)
|
||||
case fieldContentVector:
|
||||
var v []float64
|
||||
for _, item := range val.([]interface{}) {
|
||||
v = append(v, item.(float64))
|
||||
}
|
||||
doc.WithDenseVector(v)
|
||||
case fieldExtraLocation:
|
||||
doc.MetaData[docExtraLocation] = val.(string)
|
||||
}
|
||||
}
|
||||
|
||||
if hit.Score_ != nil {
|
||||
doc.WithScore(float64(*hit.Score_))
|
||||
}
|
||||
|
||||
return doc, nil
|
||||
},
|
||||
Embedding: EmbedderDashscope,
|
||||
})
|
||||
|
||||
// 不带过滤器的搜索
|
||||
docs, _ := retriever.Retrieve(ctx, "tourist attraction")
|
||||
|
||||
// 带过滤器的搜索
|
||||
docs, _ = retriever.Retrieve(ctx, "tourist attraction",
|
||||
es8.WithFilters([]types.Query{{
|
||||
Term: map[string]types.TermQuery{
|
||||
fieldExtraLocation: {
|
||||
CaseInsensitive: of(true),
|
||||
Value: "China",
|
||||
},
|
||||
},
|
||||
}}),
|
||||
)
|
||||
|
||||
fmt.Printf("retrieved docs: %+v\n", docs)
|
||||
}
|
||||
|
||||
func of[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
125
common/eino/chat_model.go
Normal file
125
common/eino/chat_model.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/qwen"
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/glog"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var globalChatModel *qwen.ChatModel
|
||||
|
||||
func init() {
|
||||
ctx := context.Background()
|
||||
|
||||
apiKey := g.Cfg().MustGet(ctx, "eino.chatmodel.apiKey").String()
|
||||
model := g.Cfg().MustGet(ctx, "eino.chatmodel.model").String()
|
||||
|
||||
var err error
|
||||
globalChatModel, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
BaseURL: "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
Temperature: gconv.PtrFloat32(0.7), // 客服最佳
|
||||
MaxTokens: gconv.PtrInt(1024), // 最长回答
|
||||
TopP: gconv.PtrFloat32(1.0),
|
||||
})
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "初始化大模型失败: %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// NewChatModel 只处理逻辑,不复用创建模型
|
||||
func NewChatModel(ctx context.Context, content string, docs []*schema.Document) (replyMsg *schema.Message, sources []string, err error) {
|
||||
// 1. 构建参考知识
|
||||
knowledge, sources := buildKnowledgeAndSources(docs)
|
||||
|
||||
// 2. 构建提示词
|
||||
msgs, err := buildPromptMessages(ctx, knowledge, content)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 🔥 直接使用全局单例,不重复创建
|
||||
replyMsg, err = streamGenerateAnswer(ctx, globalChatModel, msgs)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// buildKnowledgeAndSources 拼接参考知识 + 提取文档来源
|
||||
func buildKnowledgeAndSources(docs []*schema.Document) (string, []string) {
|
||||
var knowledge string
|
||||
var sources []string
|
||||
|
||||
for i, doc := range docs {
|
||||
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
|
||||
|
||||
// 提取 document_id
|
||||
if docID, ok := doc.MetaData["document_id"].(int64); ok && docID > 0 {
|
||||
sources = append(sources, gconv.String(docID))
|
||||
}
|
||||
}
|
||||
return knowledge, sources
|
||||
}
|
||||
|
||||
// buildPromptMessages 构建提示词模板
|
||||
func buildPromptMessages(ctx context.Context, knowledge string, question string) (msgs []*schema.Message, err error) {
|
||||
promptTpl := prompt.FromMessages(
|
||||
schema.FString,
|
||||
&schema.Message{
|
||||
Role: schema.System,
|
||||
// Content: `你是专业的客服助手,语气友好。
|
||||
//如果参考知识中有相关信息,请优先依据参考知识回答。
|
||||
//如果没有相关信息,就正常回答,不要说无法回答。
|
||||
//
|
||||
//参考知识:
|
||||
//{knowledge}`,
|
||||
Content: `你是专业的客服助手,语气友好。
|
||||
请根据参考知识回答用户问题,无法回答则说:抱歉,我暂时无法回答这个问题。
|
||||
|
||||
参考知识:
|
||||
{knowledge}`,
|
||||
},
|
||||
&schema.Message{
|
||||
Role: schema.User,
|
||||
Content: "{question}",
|
||||
},
|
||||
)
|
||||
|
||||
return promptTpl.Format(ctx, map[string]any{
|
||||
"knowledge": knowledge,
|
||||
"question": question,
|
||||
})
|
||||
}
|
||||
|
||||
// streamGenerateAnswer 流式生成
|
||||
func streamGenerateAnswer(ctx context.Context, chatModel *qwen.ChatModel, msgs []*schema.Message) (reply *schema.Message, err error) {
|
||||
|
||||
sr, err := chatModel.Stream(ctx, msgs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stream failed: %w", err)
|
||||
}
|
||||
|
||||
var chunks []*schema.Message
|
||||
for {
|
||||
chunk, err := sr.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stream recv failed: %w", err)
|
||||
}
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
return schema.ConcatMessages(chunks)
|
||||
}
|
||||
@@ -1,273 +0,0 @@
|
||||
/*
|
||||
* Copyright 2024 Red Future Authors
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/embedding"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/net/gclient"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var (
|
||||
// 千问API默认配置
|
||||
defaultBaseURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
|
||||
defaultTimeout = 10 * time.Minute
|
||||
defaultRetryTimes = 2
|
||||
)
|
||||
|
||||
type QwenEmbeddingConfig struct {
|
||||
// Timeout specifies the maximum duration to wait for API responses
|
||||
// Optional. Default: 10 minutes
|
||||
Timeout *time.Duration `json:"timeout"`
|
||||
|
||||
// HTTPClient specifies the client to send HTTP requests.
|
||||
// Optional. Default &http.Client{Timeout: Timeout}
|
||||
HTTPClient *http.Client `json:"http_client"`
|
||||
|
||||
// RetryTimes specifies the number of retry attempts for failed API calls
|
||||
// Optional. Default: 2
|
||||
RetryTimes *int `json:"retry_times"`
|
||||
|
||||
// BaseURL specifies the base URL for Qwen DashScope service
|
||||
// Optional. Default: "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding"
|
||||
BaseURL string `json:"base_url"`
|
||||
|
||||
// APIKey specifies the API Key for authentication
|
||||
// Required
|
||||
APIKey string `json:"api_key"`
|
||||
|
||||
// Model specifies the model name for Qwen embedding
|
||||
// Required. Examples: "text-embedding-v2", "text-embedding-v3"
|
||||
Model string `json:"model"`
|
||||
|
||||
// TextType specifies the type of text: "document" or "query"
|
||||
// Optional. Default: "document"
|
||||
TextType string `json:"text_type"`
|
||||
|
||||
// MaxConcurrentRequests specifies the maximum number of concurrent requests allowed
|
||||
// Optional. Default: 5
|
||||
MaxConcurrentRequests *int `json:"max_concurrent_requests"`
|
||||
}
|
||||
|
||||
type QwenEmbedder struct {
|
||||
client *gclient.Client
|
||||
conf *QwenEmbeddingConfig
|
||||
}
|
||||
|
||||
// EmbeddingRequest 千问embedding请求结构
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
} `json:"input"`
|
||||
Parameters struct {
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
// EmbeddingResponse 千问embedding响应结构
|
||||
type EmbeddingResponse struct {
|
||||
Output struct {
|
||||
Embeddings []struct {
|
||||
TextIndex int `json:"text_index"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
} `json:"embeddings"`
|
||||
} `json:"output"`
|
||||
Usage struct {
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
type APIError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestID string `json:"request_id"`
|
||||
}
|
||||
|
||||
func (e *APIError) Error() string {
|
||||
return fmt.Sprintf("API Error: %s - %s (RequestID: %s)", e.Code, e.Message, e.RequestID)
|
||||
}
|
||||
|
||||
func buildQwenClient(config *QwenEmbeddingConfig) *gclient.Client {
|
||||
if len(config.BaseURL) == 0 {
|
||||
config.BaseURL = defaultBaseURL
|
||||
}
|
||||
if config.Timeout == nil {
|
||||
config.Timeout = &defaultTimeout
|
||||
}
|
||||
if config.RetryTimes == nil {
|
||||
defaultRetryTimes := 2
|
||||
config.RetryTimes = &defaultRetryTimes
|
||||
}
|
||||
if len(config.TextType) == 0 {
|
||||
config.TextType = "document"
|
||||
}
|
||||
if config.MaxConcurrentRequests == nil {
|
||||
defaultMaxConcurrentRequests := 5
|
||||
config.MaxConcurrentRequests = &defaultMaxConcurrentRequests
|
||||
}
|
||||
|
||||
client := g.Client()
|
||||
client.SetTimeout(*config.Timeout)
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
func NewQwenEmbedder(ctx context.Context, config *QwenEmbeddingConfig) (*QwenEmbedder, error) {
|
||||
if len(config.APIKey) == 0 {
|
||||
return nil, fmt.Errorf("[Qwen] APIKey is required")
|
||||
}
|
||||
if len(config.Model) == 0 {
|
||||
return nil, fmt.Errorf("[Qwen] Model is required")
|
||||
}
|
||||
|
||||
client := buildQwenClient(config)
|
||||
|
||||
return &QwenEmbedder{
|
||||
client: client,
|
||||
conf: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *QwenEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) (
|
||||
[][]float64, error) {
|
||||
|
||||
if len(texts) == 0 {
|
||||
return nil, fmt.Errorf("[Qwen] texts cannot be empty")
|
||||
}
|
||||
|
||||
options := embedding.GetCommonOptions(&embedding.Options{
|
||||
Model: &e.conf.Model,
|
||||
}, opts...)
|
||||
|
||||
conf := &embedding.Config{
|
||||
Model: dereferenceOrZero(options.Model),
|
||||
}
|
||||
|
||||
ctx = callbacks.EnsureRunInfo(ctx, e.GetType(), components.ComponentOfEmbedding)
|
||||
ctx = callbacks.OnStart(ctx, &embedding.CallbackInput{
|
||||
Texts: texts,
|
||||
Config: conf,
|
||||
})
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
callbacks.OnError(ctx, fmt.Errorf("[Qwen] panic: %v", err))
|
||||
}
|
||||
}()
|
||||
|
||||
var usage *embedding.TokenUsage
|
||||
var embeddings [][]float64
|
||||
var err error
|
||||
|
||||
// 调用千问API获取embedding
|
||||
embeddings, usage, err = e.callEmbeddingAPI(ctx, texts)
|
||||
if err != nil {
|
||||
callbacks.OnError(ctx, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
callbacks.OnEnd(ctx, &embedding.CallbackOutput{
|
||||
Embeddings: embeddings,
|
||||
Config: conf,
|
||||
TokenUsage: usage,
|
||||
})
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (e *QwenEmbedder) callEmbeddingAPI(ctx context.Context, texts []string) ([][]float64, *embedding.TokenUsage, error) {
|
||||
// 构建请求
|
||||
var req EmbeddingRequest
|
||||
req.Model = e.conf.Model
|
||||
req.Input.Texts = texts
|
||||
req.Parameters.TextType = e.conf.TextType
|
||||
|
||||
// 调用API
|
||||
client := e.client.Clone()
|
||||
client.SetHeader("Authorization", "Bearer "+e.conf.APIKey)
|
||||
client.SetHeader("Content-Type", "application/json")
|
||||
client.SetTimeout(*e.conf.Timeout)
|
||||
|
||||
resp, err := client.Post(ctx, e.conf.BaseURL, req)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("[Qwen] HTTP request error: %w", err)
|
||||
}
|
||||
|
||||
defer resp.Close()
|
||||
|
||||
// 检查状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp APIError
|
||||
result := resp.ReadAll()
|
||||
if err = gconv.Struct(result, &errResp); err == nil && errResp.Code != "" {
|
||||
return nil, nil, &errResp
|
||||
}
|
||||
return nil, nil, fmt.Errorf("[Qwen] HTTP status error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var apiResp EmbeddingResponse
|
||||
result := resp.ReadAll()
|
||||
if err = gconv.Struct(result, &apiResp); err != nil {
|
||||
return nil, nil, fmt.Errorf("[Qwen] parse response error: %w", err)
|
||||
}
|
||||
|
||||
// 解析响应结果
|
||||
embeddings := make([][]float64, len(texts))
|
||||
for _, emb := range apiResp.Output.Embeddings {
|
||||
if emb.TextIndex >= 0 && emb.TextIndex < len(embeddings) {
|
||||
embeddings[emb.TextIndex] = emb.Embedding
|
||||
}
|
||||
}
|
||||
|
||||
usage := &embedding.TokenUsage{
|
||||
TotalTokens: apiResp.Usage.TotalTokens,
|
||||
}
|
||||
|
||||
g.Log().Debugf(ctx, "[Qwen] Embedding success: request_id=%s, total_tokens=%d", apiResp.RequestID, usage.TotalTokens)
|
||||
|
||||
return embeddings, usage, nil
|
||||
}
|
||||
|
||||
func (e *QwenEmbedder) GetType() string {
|
||||
return getType()
|
||||
}
|
||||
|
||||
func (e *QwenEmbedder) IsCallbacksEnabled() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func getType() string {
|
||||
return "Qwen"
|
||||
}
|
||||
|
||||
func dereferenceOrZero[T any](v *T) T {
|
||||
if v == nil {
|
||||
var t T
|
||||
return t
|
||||
}
|
||||
return *v
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package eino
|
||||
|
||||
// TaskPriority 任务优先级
|
||||
type TaskPriority string
|
||||
|
||||
const (
|
||||
TaskPriorityLow TaskPriority = "low" // 低优先级
|
||||
TaskPriorityMedium TaskPriority = "medium" // 中优先级
|
||||
TaskPriorityHigh TaskPriority = "high" // 高优先级
|
||||
TaskPriorityUrgent TaskPriority = "urgent" // 紧急
|
||||
)
|
||||
@@ -3,6 +3,8 @@ package eino
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"rag/dao"
|
||||
"sort"
|
||||
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components/embedding"
|
||||
@@ -16,12 +18,14 @@ type PGVectorRetrieverConfig struct {
|
||||
Embedder embedding.Embedder
|
||||
DefaultTopK int
|
||||
DefaultIndex string
|
||||
DSLInfo map[string]any
|
||||
}
|
||||
|
||||
type PGVectorRetriever struct {
|
||||
embedder embedding.Embedder
|
||||
topK int
|
||||
index string
|
||||
dslInfo map[string]any
|
||||
}
|
||||
|
||||
func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever, error) {
|
||||
@@ -36,43 +40,62 @@ func NewPGVectorRetriever(config *PGVectorRetrieverConfig) (*PGVectorRetriever,
|
||||
embedder: config.Embedder,
|
||||
topK: config.DefaultTopK,
|
||||
index: config.DefaultIndex,
|
||||
dslInfo: config.DSLInfo,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *PGVectorRetriever) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
|
||||
|
||||
// 1. 处理公共 Option(官方标准写法)
|
||||
options := &retriever.Options{
|
||||
Index: &r.index,
|
||||
TopK: &r.topK,
|
||||
DSLInfo: r.dslInfo,
|
||||
Embedding: r.embedder,
|
||||
}
|
||||
options = retriever.GetCommonOptions(options, opts...)
|
||||
|
||||
// 2. 回调(官方标准)
|
||||
ctx = callbacks.OnStart(ctx, &retriever.CallbackInput{
|
||||
Query: query,
|
||||
TopK: *options.TopK,
|
||||
})
|
||||
|
||||
// 3. 执行检索
|
||||
docs, err := r.doRetrieve(ctx, query, options)
|
||||
// ==========================================
|
||||
// 🔥 双路检索:向量 + 全文
|
||||
// ==========================================
|
||||
docsVector, err := r.doRetrieveVector(ctx, query, options)
|
||||
if err != nil {
|
||||
callbacks.OnError(ctx, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. 完成回调
|
||||
callbacks.OnEnd(ctx, &retriever.CallbackOutput{
|
||||
Docs: docs,
|
||||
docsFulltext, err := r.doRetrieveMeilisearch(ctx, query, options)
|
||||
if err != nil {
|
||||
callbacks.OnError(ctx, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 合并 + 去重
|
||||
docs := mergeAndDeduplicate(docsVector, docsFulltext)
|
||||
|
||||
// 排序(distance 越小越靠前)
|
||||
sort.Slice(docs, func(i, j int) bool {
|
||||
d1 := gconv.Float64(docs[i].MetaData["distance"])
|
||||
d2 := gconv.Float64(docs[j].MetaData["distance"])
|
||||
return d1 < d2
|
||||
})
|
||||
|
||||
// 最多保留 topK
|
||||
if len(docs) > *options.TopK {
|
||||
docs = docs[:*options.TopK]
|
||||
}
|
||||
|
||||
callbacks.OnEnd(ctx, &retriever.CallbackOutput{Docs: docs})
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||||
|
||||
// 1. 生成向量
|
||||
// ==========================================
|
||||
// 1. 向量检索(PG)
|
||||
// ==========================================
|
||||
func (r *PGVectorRetriever) doRetrieveVector(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||||
vectors, err := opts.Embedding.EmbedStrings(ctx, []string{query})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -81,37 +104,76 @@ func (r *PGVectorRetriever) doRetrieve(ctx context.Context, query string, opts *
|
||||
return nil, errors.New("empty query vector")
|
||||
}
|
||||
|
||||
queryVec := pgvector.NewVector(vectors[0])
|
||||
queryVec := pgvector.NewVector(gconv.Float32s(vectors[0]))
|
||||
topK := *opts.TopK
|
||||
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
|
||||
// 2. PG 向量相似度检索 SQL
|
||||
sql := `
|
||||
SELECT id, content, dataset_id, document_id,
|
||||
vector <-> ? AS distance
|
||||
FROM document_chunk
|
||||
ORDER BY distance ASC
|
||||
LIMIT ?
|
||||
`
|
||||
|
||||
// 3. 查询
|
||||
rows, err := dao.DocumentChunk.GetDB().GetAll(ctx, sql, queryVec, topK)
|
||||
rows, err := dao.DocumentChunk.GetAllByVector(ctx, datasetIds, queryVec, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. 转为 Eino Document
|
||||
docs := make([]*schema.Document, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
docs = append(docs, &schema.Document{
|
||||
ID: gconv.String(row["id"]),
|
||||
Content: gconv.String(row["content"]),
|
||||
Metadata: map[string]any{
|
||||
"dataset_id": row["dataset_id"],
|
||||
"document_id": row["document_id"],
|
||||
"distance": row["distance"],
|
||||
MetaData: map[string]any{
|
||||
"dataset_id": gconv.Int64(row["dataset_id"]),
|
||||
"document_id": gconv.Int64(row["document_id"]),
|
||||
"distance": gconv.Float64(row["distance"]),
|
||||
"retrieve_by": "vector",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 2. 全文检索(Meilisearch)🔥 新增
|
||||
// ==========================================
|
||||
func (r *PGVectorRetriever) doRetrieveMeilisearch(ctx context.Context, query string, opts *retriever.Options) ([]*schema.Document, error) {
|
||||
topK := *opts.TopK
|
||||
datasetIds := gconv.Int64s(opts.DSLInfo["dataset_ids"])
|
||||
|
||||
// 调用你已有的 Meilisearch DAO
|
||||
rows, err := dao.DocumentChunk.SearchByKeywords(ctx, query, datasetIds, topK)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
docs := make([]*schema.Document, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
docs = append(docs, &schema.Document{
|
||||
ID: gconv.String(row["id"]),
|
||||
Content: gconv.String(row["content"]),
|
||||
MetaData: map[string]any{
|
||||
"dataset_id": gconv.Int64(row["dataset_id"]),
|
||||
"document_id": gconv.Int64(row["document_id"]),
|
||||
"distance": 0.1, // 全文结果给高分
|
||||
"retrieve_by": "fulltext",
|
||||
},
|
||||
})
|
||||
}
|
||||
return docs, nil
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 合并去重
|
||||
// ==========================================
|
||||
func mergeAndDeduplicate(vecDocs, fullDocs []*schema.Document) []*schema.Document {
|
||||
idMap := make(map[string]*schema.Document)
|
||||
for _, d := range vecDocs {
|
||||
idMap[d.ID] = d
|
||||
}
|
||||
for _, d := range fullDocs {
|
||||
if _, exists := idMap[d.ID]; !exists {
|
||||
idMap[d.ID] = d
|
||||
}
|
||||
}
|
||||
merged := make([]*schema.Document, 0, len(idMap))
|
||||
for _, d := range idMap {
|
||||
merged = append(merged, d)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package eino
|
||||
|
||||
// TaskStatus 任务状态
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
TaskStatusPending TaskStatus = "pending" // 待处理
|
||||
TaskStatusRunning TaskStatus = "running" // 运行中
|
||||
TaskStatusCompleted TaskStatus = "completed" // 已完成
|
||||
TaskStatusFailed TaskStatus = "failed" // 失败
|
||||
TaskStatusCancelled TaskStatus = "cancelled" // 已取消
|
||||
)
|
||||
@@ -1,14 +0,0 @@
|
||||
package eino
|
||||
|
||||
// TaskType 任务类型
|
||||
type TaskType string
|
||||
|
||||
const (
|
||||
TaskTypeDocumentIngestion TaskType = "document_ingestion" // 文档摄入任务
|
||||
TaskTypeVectorIngestion TaskType = "vector_ingestion" // 向量摄入任务
|
||||
TaskTypeIndexCreation TaskType = "index_creation" // 索引创建任务
|
||||
TaskTypeQAProcessing TaskType = "qa_processing" // 问答处理任务
|
||||
TaskTypeKnowledgeConstruction TaskType = "knowledge_construction" // 知识库构建任务
|
||||
TaskTypeGraphBuilding TaskType = "graph_building" // 图谱构建任务
|
||||
TaskTypeKnowledgeSync TaskType = "knowledge_sync" // 知识同步任务
|
||||
)
|
||||
Reference in New Issue
Block a user