Files
rag/common/eino/chat.go

244 lines
7.6 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package eino
import (
"context"
"fmt"
"rag/consts/model"
"rag/dao"
"rag/model/dto"
"rag/model/entity"
"gitea.com/red-future/common/jaeger"
"gitea.com/red-future/common/utils"
"github.com/cloudwego/eino-ext/components/model/ark"
"github.com/cloudwego/eino-ext/components/model/arkbot"
"github.com/cloudwego/eino-ext/components/model/claude"
"github.com/cloudwego/eino-ext/components/model/deepseek"
"github.com/cloudwego/eino-ext/components/model/ollama"
"github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino-ext/components/model/qianfan"
"github.com/cloudwego/eino-ext/components/model/qwen"
modelChat "github.com/cloudwego/eino/components/model"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// ChatModelSet holds a tenant's chat model clients, one field per supported
// provider. InitChat populates the field matching a model's ConfigType and
// GetTenantChatModelByType selects a field by config type. A field stays nil
// when no model of that provider has been initialised for the tenant.
type ChatModelSet struct {
	Ark      *ark.ChatModel      // set for ModelConfigTypeChatArk
	ArkBot   *arkbot.ChatModel   // set for ModelConfigTypeChatArkBot
	Claude   *claude.ChatModel   // set for ModelConfigTypeChatClaude
	DeepSeek *deepseek.ChatModel // set for ModelConfigTypeChatDeepSeek
	Ollama   *ollama.ChatModel   // set for ModelConfigTypeChatOllama
	OpenAI   *openai.ChatModel   // set for ModelConfigTypeChatOpenAI
	Qianfan  *qianfan.ChatModel  // set for ModelConfigTypeChatQianfan
	Qwen     *qwen.ChatModel     // set for ModelConfigTypeChatQwen
}
// tenantChatModels is the global tenant registry: the key is the tenant ID and
// the value holds that tenant's chat model clients.
//
// NOTE(review): this map is written by InitChat (and RefreshTenantChatModel)
// and read by GetTenantChatModel with no synchronisation. If a refresh can run
// while requests are being served, the Go runtime will abort with
// "concurrent map read and map write" — consider guarding it with a
// sync.RWMutex.
var tenantChatModels = map[uint64]*ChatModelSet{}
// init loads every tenant's chat models when the package is initialised,
// inside a Jaeger span named "InitAllChat".
//
// NOTE(review): InitAllChat queries the dao layer, so this presumably depends
// on the database/config being usable before this package's init runs —
// confirm the initialisation order guarantees that.
func init() {
	ctx, span := jaeger.NewSpan(context.Background(), "InitAllChat")
	defer span.End()

	InitAllChat(ctx)
}
// ===================== 1. 服务启动时:初始化所有租户对话模型 =====================
func InitAllChat(ctx context.Context) {
list, err := dao.Model.GetNoTenantId(ctx, &dto.GetModelReq{
ModelType: model.ModelTypeChat.Code(),
})
if err != nil {
g.Log().Errorf(ctx, "获取所有租户对话模型失败: %v", err)
return
}
for _, l := range list {
err = InitChat(ctx, l)
if err != nil {
g.Log().Errorf(ctx, "初始化租户[%v]的对话模型失败: %v", l.TenantId, err)
continue
}
}
}
// InitChat builds the chat model client described by modelDO and registers it
// for modelDO.TenantId.
//
// modelDO.ConfigType selects the provider, and modelDO.ConfigContent is
// decoded (via gconv.Struct) into that provider's entity config. Every
// provider except Ollama is created with the same sampling defaults:
// temperature 0.7, at most 1024 output tokens, and top-p 1.0.
//
// A tenant keeps one client per provider. The tenant's existing ChatModelSet,
// if any, is copied and only the field for this provider is replaced, so
// initialising a second model for the same tenant (e.g. OpenAI after Ark) no
// longer discards the first one. The copy is stored as a fresh pointer, so
// code already holding the previous *ChatModelSet never observes a partially
// updated value.
//
// It returns an error, and leaves the registry untouched, when:
//   - modelDO or its ConfigType is nil;
//   - the config type is not a supported chat provider;
//   - ConfigContent cannot be decoded;
//   - the provider constructor fails.
//
// Underlying errors are wrapped with %w so callers can use errors.Is/As.
func InitChat(ctx context.Context, modelDO *entity.Model) (err error) {
	if modelDO == nil || modelDO.ConfigType == nil {
		return fmt.Errorf("对话模型配置类型为空")
	}

	// Sampling defaults shared by all providers that accept them.
	const (
		defaultTemperature = 0.7
		defaultMaxTokens   = 1024
		defaultTopP        = 1.0
	)

	// Copy-on-write: start from the tenant's current set so models of other
	// providers survive, without mutating a set other callers may hold.
	set := &ChatModelSet{}
	if prev := tenantChatModels[modelDO.TenantId]; prev != nil {
		*set = *prev
	}

	switch *modelDO.ConfigType {
	case *model.ModelConfigTypeChatArk.Code():
		var cfg entity.ChatModelConfigArk
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析Ark配置失败: %w", err)
		}
		set.Ark, err = ark.NewChatModel(ctx, &ark.ChatModelConfig{
			APIKey:      cfg.APIKey,
			Model:       cfg.Model,
			Temperature: gconv.PtrFloat32(defaultTemperature),
			MaxTokens:   gconv.PtrInt(defaultMaxTokens),
			TopP:        gconv.PtrFloat32(defaultTopP),
		})
	case *model.ModelConfigTypeChatArkBot.Code():
		var cfg entity.ChatModelConfigArkBot
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析ArkBot配置失败: %w", err)
		}
		set.ArkBot, err = arkbot.NewChatModel(ctx, &arkbot.Config{
			APIKey:      cfg.APIKey,
			Model:       cfg.Model,
			Temperature: gconv.PtrFloat32(defaultTemperature),
			MaxTokens:   gconv.PtrInt(defaultMaxTokens),
			TopP:        gconv.PtrFloat32(defaultTopP),
		})
	case *model.ModelConfigTypeChatClaude.Code():
		var cfg entity.ChatModelConfigClaude
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析Claude配置失败: %w", err)
		}
		claudeCfg := claude.Config{
			APIKey:          cfg.APIKey,
			Model:           cfg.Model,
			Temperature:     gconv.PtrFloat32(defaultTemperature),
			MaxTokens:       defaultMaxTokens,
			TopP:            gconv.PtrFloat32(defaultTopP),
			ByBedrock:       cfg.ByBedrock,
			AccessKey:       cfg.AccessKey,
			SecretAccessKey: cfg.SecretAccessKey,
			Region:          cfg.Region,
		}
		// Only override the base URL when one is configured. A pointer to ""
		// would replace the client's default endpoint with an empty URL.
		if baseURL := gconv.String(cfg.BaseURL); baseURL != "" {
			claudeCfg.BaseURL = &baseURL
		}
		set.Claude, err = claude.NewChatModel(ctx, &claudeCfg)
	case *model.ModelConfigTypeChatDeepSeek.Code():
		var cfg entity.ChatModelConfigDeepSeek
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析DeepSeek配置失败: %w", err)
		}
		set.DeepSeek, err = deepseek.NewChatModel(ctx, &deepseek.ChatModelConfig{
			APIKey:      cfg.APIKey,
			Model:       cfg.Model,
			BaseURL:     cfg.BaseURL,
			Temperature: defaultTemperature,
			MaxTokens:   defaultMaxTokens,
			TopP:        defaultTopP,
		})
	case *model.ModelConfigTypeChatOllama.Code():
		var cfg entity.ChatModelConfigOllama
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析Ollama配置失败: %w", err)
		}
		set.Ollama, err = ollama.NewChatModel(ctx, &ollama.ChatModelConfig{
			BaseURL: cfg.BaseURL,
			Model:   cfg.Model,
		})
	case *model.ModelConfigTypeChatOpenAI.Code():
		var cfg entity.ChatModelConfigOpenAI
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析OpenAI配置失败: %w", err)
		}
		openAiCfg := openai.ChatModelConfig{
			APIKey:              cfg.APIKey,
			Model:               cfg.Model,
			ByAzure:             cfg.ByAzure,
			BaseURL:             cfg.BaseURL,
			APIVersion:          cfg.APIVersion,
			Temperature:         gconv.PtrFloat32(defaultTemperature),
			MaxCompletionTokens: gconv.PtrInt(defaultMaxTokens),
			TopP:                gconv.PtrFloat32(defaultTopP),
		}
		set.OpenAI, err = openai.NewChatModel(ctx, &openAiCfg)
	case *model.ModelConfigTypeChatQianfan.Code():
		var cfg entity.ChatModelConfigQianfan
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析千帆配置失败: %w", err)
		}
		// NOTE(review): judging by its name, GetQianfanSingletonConfig returns a
		// process-wide config. If so, these credentials are shared by every
		// tenant's Qianfan model, and the tenant initialised last wins.
		// Confirm whether per-tenant Qianfan credentials are required.
		qcfg := qianfan.GetQianfanSingletonConfig()
		qcfg.AccessKey = cfg.AccessKey
		qcfg.SecretKey = cfg.SecretKey
		set.Qianfan, err = qianfan.NewChatModel(ctx, &qianfan.ChatModelConfig{
			Model:               cfg.Model,
			Temperature:         gconv.PtrFloat32(defaultTemperature),
			MaxCompletionTokens: gconv.PtrInt(defaultMaxTokens),
			TopP:                gconv.PtrFloat32(defaultTopP),
		})
	case *model.ModelConfigTypeChatQwen.Code():
		var cfg entity.ChatModelConfigQwen
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析Qwen配置失败: %w", err)
		}
		set.Qwen, err = qwen.NewChatModel(ctx, &qwen.ChatModelConfig{
			APIKey:      cfg.APIKey,
			Model:       cfg.Model,
			BaseURL:     cfg.BaseURL,
			Temperature: gconv.PtrFloat32(defaultTemperature),
			MaxTokens:   gconv.PtrInt(defaultMaxTokens),
			TopP:        gconv.PtrFloat32(defaultTopP),
		})
	default:
		return fmt.Errorf("不支持的对话模型类型: %v", *modelDO.ConfigType)
	}
	if err != nil {
		return fmt.Errorf("初始化对话模型失败: %w", err)
	}

	// Stored only on success, without a lock (see the NOTE on tenantChatModels).
	tenantChatModels[modelDO.TenantId] = set
	g.Log().Infof(ctx, "租户[%v]对话模型[%v]初始化成功", modelDO.TenantId, *modelDO.ConfigType)
	return nil
}
// GetTenantChatModel returns the ChatModelSet registered for tenantId, or an
// error when no chat model has been initialised for that tenant.
func GetTenantChatModel(tenantId uint64) (*ChatModelSet, error) {
	if set, found := tenantChatModels[tenantId]; found && set != nil {
		return set, nil
	}
	return nil, fmt.Errorf("租户[%v]对话模型未初始化", tenantId)
}
// GetTenantChatModelByType returns the chat model of the given config type for
// the tenant of the user stored in ctx.
//
// It returns an error when:
//   - the user info cannot be read from ctx;
//   - the tenant has no initialised chat models;
//   - configType is nil or not a supported chat type;
//   - the tenant has no model of that type.
//
// A nil model is never returned together with a nil error. The ChatModelSet
// fields are concrete pointer types, and returning a nil one through the
// BaseChatModel interface would produce a non-nil interface value that
// panics on first use.
func GetTenantChatModelByType(ctx context.Context, configType model.ModelConfigType) (modelChat.BaseChatModel, error) {
	userInfo, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}
	set, err := GetTenantChatModel(userInfo.TenantId)
	if err != nil {
		return nil, err
	}
	if configType == nil {
		return nil, fmt.Errorf("对话模型类型为空")
	}

	// Each case returns only a non-nil client; a nil field falls through to the
	// "not initialised" error below.
	switch *configType {
	case *model.ModelConfigTypeChatArk.Code():
		if set.Ark != nil {
			return set.Ark, nil
		}
	case *model.ModelConfigTypeChatArkBot.Code():
		if set.ArkBot != nil {
			return set.ArkBot, nil
		}
	case *model.ModelConfigTypeChatClaude.Code():
		if set.Claude != nil {
			return set.Claude, nil
		}
	case *model.ModelConfigTypeChatDeepSeek.Code():
		if set.DeepSeek != nil {
			return set.DeepSeek, nil
		}
	case *model.ModelConfigTypeChatOllama.Code():
		if set.Ollama != nil {
			return set.Ollama, nil
		}
	case *model.ModelConfigTypeChatOpenAI.Code():
		if set.OpenAI != nil {
			return set.OpenAI, nil
		}
	case *model.ModelConfigTypeChatQianfan.Code():
		if set.Qianfan != nil {
			return set.Qianfan, nil
		}
	case *model.ModelConfigTypeChatQwen.Code():
		if set.Qwen != nil {
			return set.Qwen, nil
		}
	default:
		return nil, fmt.Errorf("不支持的对话模型类型: %v", *configType)
	}
	return nil, fmt.Errorf("租户[%v]对话模型[%v]未初始化", userInfo.TenantId, *configType)
}
// RefreshTenantChatModel re-initialises the chat model described by modelDO
// (e.g. after its configuration changed) and returns InitChat's error.
//
// The tenant's existing entry is deliberately not deleted first. InitChat
// stores a result only after the new client has been built, so:
//   - a failed refresh keeps the tenant's previously working model(s) instead
//     of leaving the tenant with nothing;
//   - other callers never see a window in which the tenant appears
//     uninitialised.
func RefreshTenantChatModel(ctx context.Context, modelDO *entity.Model) error {
	return InitChat(ctx, modelDO)
}