126 lines
2.9 KiB
Go
126 lines
2.9 KiB
Go
package eino
|
||
|
||
import (
	"context"
	"errors"
	"fmt"
	"io"
	"strings"

	"github.com/cloudwego/eino/components/prompt"
	"github.com/cloudwego/eino/schema"

	"rag/consts/model"
)
|
||
|
||
// MaxHistoryTurns is the maximum number of past conversation turns forwarded
// to the model. limitHistory keeps at most 2*MaxHistoryTurns user/assistant
// messages, so one turn counts as a user message plus an assistant reply.
const MaxHistoryTurns = 5
|
||
|
||
var (
|
||
ragPromptTemplate prompt.ChatTemplate // EINO 官方模板
|
||
)
|
||
|
||
func init() {
|
||
// 初始化 EINO 提示词模板
|
||
initRAGPromptTemplate()
|
||
return
|
||
}
|
||
|
||
// 初始化 EINO 官方提示词模板(最关键!)
|
||
func initRAGPromptTemplate() {
|
||
ragPromptTemplate = prompt.FromMessages(
|
||
schema.FString,
|
||
// 系统提示(带参考知识)
|
||
&schema.Message{
|
||
Role: schema.System,
|
||
Content: `你是专业客服,语气友好简洁。
|
||
请依据参考知识回答,不知道就说:抱歉,我暂时无法回答这个问题。
|
||
|
||
参考知识:
|
||
{knowledge}`,
|
||
},
|
||
// 用户问题
|
||
&schema.Message{
|
||
Role: schema.User,
|
||
Content: "{question}",
|
||
},
|
||
)
|
||
}
|
||
|
||
// NewChatModel 只处理逻辑,不复用创建模型
|
||
func NewChatModel(ctx context.Context, question string, docs []*schema.Document, history []*schema.Message, chatModel model.ModelConfigType) (replyMsg *schema.Message, err error) {
|
||
// 1. 构建参考知识
|
||
knowledge := buildKnowledgeAndSources(docs)
|
||
// 2. 历史精简
|
||
history = limitHistory(history)
|
||
// 3. ✅ EINO 官方模板格式化(超级干净)
|
||
msgs, err := ragPromptTemplate.Format(ctx, map[string]any{
|
||
"knowledge": knowledge,
|
||
"question": question,
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
// 4. 历史插入到模板消息中间(标准EINO用法)
|
||
if len(history) > 0 {
|
||
msgs = append(msgs[:1], append(history, msgs[1:]...)...)
|
||
}
|
||
// 5. 🔥 直接使用全局单例,不重复创建
|
||
replyMsg, err = streamGenerateAnswer(ctx, msgs, chatModel)
|
||
|
||
return
|
||
}
|
||
|
||
func limitHistory(history []*schema.Message) []*schema.Message {
|
||
valid := make([]*schema.Message, 0, len(history))
|
||
for _, m := range history {
|
||
if m.Role == schema.User || m.Role == schema.Assistant {
|
||
valid = append(valid, m)
|
||
}
|
||
}
|
||
|
||
keep := 2 * MaxHistoryTurns
|
||
if len(valid) > keep {
|
||
valid = valid[len(valid)-keep:]
|
||
}
|
||
return valid
|
||
}
|
||
|
||
// buildKnowledgeAndSources 拼接参考知识
|
||
func buildKnowledgeAndSources(docs []*schema.Document) string {
|
||
var knowledge string
|
||
|
||
for i, doc := range docs {
|
||
knowledge += fmt.Sprintf("[参考%d] %s\n", i+1, doc.Content)
|
||
|
||
}
|
||
return knowledge
|
||
}
|
||
|
||
// streamGenerateAnswer 流式生成
|
||
func streamGenerateAnswer(ctx context.Context, msgs []*schema.Message, chatModel model.ModelConfigType) (reply *schema.Message, err error) {
|
||
|
||
cm, err := GetTenantChatModelByType(ctx, chatModel)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
sr, err := cm.Stream(ctx, msgs)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("stream failed: %w", err)
|
||
}
|
||
|
||
var chunks []*schema.Message
|
||
for {
|
||
chunk, err := sr.Recv()
|
||
if errors.Is(err, io.EOF) {
|
||
break
|
||
}
|
||
if err != nil {
|
||
return nil, fmt.Errorf("stream recv failed: %w", err)
|
||
}
|
||
chunks = append(chunks, chunk)
|
||
}
|
||
|
||
return schema.ConcatMessages(chunks)
|
||
}
|