feat: test
This commit is contained in:
166
common/eino/a.go
Normal file
166
common/eino/a.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package eino
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/cloudwego/eino/components/prompt"
|
||||
"github.com/cloudwego/eino/components/retriever"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/cloudwego/eino-ext/components/model/ark"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
// ==========================================
|
||||
// 1. 初始化三大组件
|
||||
// ==========================================
|
||||
// 1.1 向量检索(从知识库查客服知识)
|
||||
ragRetriever := NewPGVectorRetriever()
|
||||
|
||||
// 1.2 提示词模板(客服角色 + 历史 + 知识库 + 用户问题)
|
||||
chatTpl := newCustomerServiceTemplate()
|
||||
|
||||
// 1.3 大模型(ARK)
|
||||
chatModel, err := ark.NewChatModel(ctx, &ark.ChatModelConfig{
|
||||
APIKey: os.Getenv("ARK_API_KEY"),
|
||||
Model: os.Getenv("ARK_MODEL_ID"),
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 2. 模拟会话:从 DB 读取历史对话
|
||||
// ==========================================
|
||||
sessionHistory := []*schema.Message{
|
||||
{Role: schema.User, Content: "你们发什么快递?"},
|
||||
{Role: schema.Assistant, Content: "默认发中通快递"},
|
||||
{Role: schema.User, Content: "可以发顺丰吗?"},
|
||||
}
|
||||
|
||||
// 当前用户问题
|
||||
userQuery := "那顺丰需要加钱吗?"
|
||||
|
||||
// ==========================================
|
||||
// 3. RAG 检索知识库
|
||||
// ==========================================
|
||||
docs, err := ragRetriever.Retrieve(ctx, userQuery)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// 拼接参考知识
|
||||
knowledge := ""
|
||||
for i, doc := range docs {
|
||||
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 4. 模板格式化:系统提示 + 历史 + 知识 + 当前问题
|
||||
// ==========================================
|
||||
msgs, err := chatTpl.Format(ctx, map[string]any{
|
||||
"history": sessionHistory,
|
||||
"knowledge": knowledge,
|
||||
"question": userQuery,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 5. 流式调用大模型生成客服回答
|
||||
// ==========================================
|
||||
fmt.Println("\n=== 客服回复 ===")
|
||||
stream, err := chatModel.Stream(ctx, msgs)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fullReply := make([]*schema.Message, 0, 100)
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Print(chunk.Content)
|
||||
fullReply = append(fullReply, chunk)
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 6. 拼接完整回复,存入 DB 作为新历史
|
||||
// ==========================================
|
||||
replyMsg, _ := schema.ConcatMessages(fullReply)
|
||||
sessionHistory = append(sessionHistory,
|
||||
&schema.Message{Role: schema.User, Content: userQuery},
|
||||
replyMsg,
|
||||
)
|
||||
|
||||
// 接下来把 sessionHistory 存回你的 MySQL/Redis 即可
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// 本地客服提示词模板(不需要 MCP)
|
||||
// ==========================================
|
||||
func newCustomerServiceTemplate() prompt.ChatTemplate {
|
||||
// 系统提示 + 多轮对话 + 知识库 + 用户问题
|
||||
return prompt.FromMessages(schema.Messages{
|
||||
{
|
||||
Role: schema.System,
|
||||
Content: `你是电商智能客服,语气友好简洁。
|
||||
请严格根据参考知识回答,不知道就说“抱歉,这个问题我需要帮你转接人工”。
|
||||
|
||||
参考知识:
|
||||
{{.knowledge}}`,
|
||||
},
|
||||
// 历史对话会自动渲染在这里
|
||||
{{range .history}}{{.}},{{end}},
|
||||
// 当前用户问题
|
||||
{Role: schema.User, Content: "{{.question}}"},
|
||||
})
|
||||
}
|
||||
|
||||
// ==========================================
|
||||
// PGVector 检索器(简化可直接用)
|
||||
// ==========================================
|
||||
type PGVectorRetriever struct {
|
||||
topK int
|
||||
}
|
||||
|
||||
func NewPGVectorRetriever() retriever.Retriever {
|
||||
return &PGVectorRetriever{topK: 3}
|
||||
}
|
||||
|
||||
func (r *PGVectorRetriever) Retrieve(
|
||||
ctx context.Context,
|
||||
query string,
|
||||
opts ...retriever.Option,
|
||||
) ([]*schema.Document, error) {
|
||||
|
||||
options := retriever.GetCommonOptions(nil, opts...)
|
||||
topK := r.topK
|
||||
if options.TopK != nil {
|
||||
topK = *options.TopK
|
||||
}
|
||||
|
||||
// ===== 这里替换成你真实的 PG 向量检索 SQL =====
|
||||
// 模拟知识库
|
||||
return []*schema.Document{
|
||||
{
|
||||
ID: "1",
|
||||
Content: "顺丰快递需要补10元运费差价",
|
||||
},
|
||||
{
|
||||
ID: "2",
|
||||
Content: "订单满99元可免费升级顺丰",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user