Files
rag/common/eino/chat_model.go

126 lines
2.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package eino
import (
	"context"
	"errors"
	"fmt"
	"io"
	"strings"

	"rag/consts/model"

	"github.com/cloudwego/eino/components/prompt"
	"github.com/cloudwego/eino/schema"
)
// MaxHistoryTurns bounds how much prior conversation is forwarded to the
// model: limitHistory keeps at most 2*MaxHistoryTurns user/assistant messages.
const MaxHistoryTurns = 5
// ragPromptTemplate is the shared RAG chat template: a system message with a
// {knowledge} slot followed by a user message holding {question}. It is
// assigned by initRAGPromptTemplate, which runs from this package's init.
var ragPromptTemplate prompt.ChatTemplate
// init builds the shared RAG prompt template when the package is loaded.
func init() {
	initRAGPromptTemplate()
}
// 初始化 EINO 官方提示词模板(最关键!)
func initRAGPromptTemplate() {
ragPromptTemplate = prompt.FromMessages(
schema.FString,
// 系统提示(带参考知识)
&schema.Message{
Role: schema.System,
Content: `你是专业客服,语气友好简洁。
请依据参考知识回答,不知道就说:抱歉,我暂时无法回答这个问题。
参考知识:
{knowledge}`,
},
// 用户问题
&schema.Message{
Role: schema.User,
Content: "{question}",
},
)
}
// NewChatModel answers question using retrieval-augmented generation. Despite
// its name it does not itself construct a model. Instead it:
//  1. renders docs into the reference-knowledge text,
//  2. trims history to recent user/assistant messages,
//  3. formats ragPromptTemplate with that knowledge and the question,
//  4. places the trimmed history between the system prompt and the question,
//  5. streams an answer from the model chosen by chatModel.
//
// It returns the concatenated assistant reply. It returns an error if prompt
// formatting fails or if generation fails.
func NewChatModel(ctx context.Context, question string, docs []*schema.Document, history []*schema.Message, chatModel model.ModelConfigType) (replyMsg *schema.Message, err error) {
	// 1. Numbered reference-knowledge block for the system prompt.
	knowledge := buildKnowledgeAndSources(docs)

	// 2. Keep only the recent user/assistant messages.
	history = limitHistory(history)

	// 3. Render the template; element 0 is the system prompt, the rest is the
	// user question.
	tmplMsgs, err := ragPromptTemplate.Format(ctx, map[string]any{
		"knowledge": knowledge,
		"question":  question,
	})
	if err != nil {
		return nil, fmt.Errorf("format rag prompt: %w", err)
	}
	if len(tmplMsgs) == 0 {
		return nil, errors.New("format rag prompt: template produced no messages")
	}

	// 4. Assemble system prompt + history + question in a fresh slice so that
	// neither tmplMsgs nor history's backing array is written through.
	msgs := make([]*schema.Message, 0, len(tmplMsgs)+len(history))
	msgs = append(msgs, tmplMsgs[0])
	msgs = append(msgs, history...)
	msgs = append(msgs, tmplMsgs[1:]...)

	// 5. Stream the answer and return the concatenated reply.
	return streamGenerateAnswer(ctx, msgs, chatModel)
}
// limitHistory returns the tail of history that is forwarded to the model.
// Only user and assistant messages are kept; nil entries and every other role
// are discarded. At most 2*MaxHistoryTurns messages are kept, which is
// MaxHistoryTurns pairs when roles alternate.
//
// When trimming cuts a turn in half, the window would start with an assistant
// reply whose user message was dropped. That orphan reply is removed too. The
// result never aliases the caller's slice.
func limitHistory(history []*schema.Message) []*schema.Message {
	valid := make([]*schema.Message, 0, len(history))
	for _, m := range history {
		if m == nil {
			continue // tolerate holes instead of panicking on m.Role
		}
		if m.Role == schema.User || m.Role == schema.Assistant {
			valid = append(valid, m)
		}
	}

	if keep := 2 * MaxHistoryTurns; len(valid) > keep {
		valid = valid[len(valid)-keep:]
		// Don't open the window with an answer to a question the model
		// will never see.
		if len(valid) > 0 && valid[0].Role == schema.Assistant {
			valid = valid[1:]
		}
	}
	return valid
}
// buildKnowledgeAndSources renders docs as the text that fills the {knowledge}
// slot of the RAG prompt. It writes one line per document in the form
// "[参考N] <content>\n", numbered from 1. Nil documents are skipped and do not
// consume a number. An empty or nil docs slice yields "".
//
// Despite the name, only the knowledge text is produced; no source list is
// built here.
func buildKnowledgeAndSources(docs []*schema.Document) string {
	// A Builder avoids the quadratic copying of repeated string +=.
	var b strings.Builder
	n := 0
	for _, doc := range docs {
		if doc == nil {
			continue
		}
		n++
		fmt.Fprintf(&b, "[参考%d] %s\n", n, doc.Content)
	}
	return b.String()
}
// streamGenerateAnswer resolves the chat model for chatModel using
// GetTenantChatModelByType and starts a streaming completion for msgs. It
// reads every chunk until io.EOF and merges them into one message with
// schema.ConcatMessages.
//
// The stream reader is closed on every exit path, including when Recv fails
// partway through.
func streamGenerateAnswer(ctx context.Context, msgs []*schema.Message, chatModel model.ModelConfigType) (reply *schema.Message, err error) {
	cm, err := GetTenantChatModelByType(ctx, chatModel)
	if err != nil {
		return nil, fmt.Errorf("get chat model: %w", err)
	}

	sr, err := cm.Stream(ctx, msgs)
	if err != nil {
		return nil, fmt.Errorf("stream failed: %w", err)
	}
	// Eino requires callers to Close a StreamReader once they stop calling
	// Recv; without this the reader (and its producer) is leaked.
	defer sr.Close()

	var chunks []*schema.Message
	for {
		chunk, recvErr := sr.Recv()
		if errors.Is(recvErr, io.EOF) {
			break
		}
		if recvErr != nil {
			return nil, fmt.Errorf("stream recv failed: %w", recvErr)
		}
		chunks = append(chunks, chunk)
	}

	reply, err = schema.ConcatMessages(chunks)
	if err != nil {
		return nil, fmt.Errorf("concat stream chunks: %w", err)
	}
	return reply, nil
}