Files
rag/common/eino/embedding.go

214 lines
6.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package eino
import (
"context"
"fmt"
"rag/consts/model"
"rag/model/entity"
"gitea.com/red-future/common/jaeger"
"gitea.com/red-future/common/utils"
"github.com/cloudwego/eino-ext/components/embedding/ark"
"github.com/cloudwego/eino-ext/components/embedding/dashscope"
"github.com/cloudwego/eino-ext/components/embedding/ollama"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/cloudwego/eino-ext/components/embedding/qianfan"
"github.com/cloudwego/eino-ext/components/embedding/tencentcloud"
"github.com/cloudwego/eino/components/embedding"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// EmbedderSet groups the embedders registered for one tenant, with one
// pointer field per supported provider. InitVector populates exactly one field
// per call and leaves the others nil, so a nil field means that provider has
// not been initialised for the tenant.
type EmbedderSet struct {
// Ark is the Volcengine Ark embedder, or nil.
Ark *ark.Embedder
// Ollama is the Ollama embedder, or nil.
Ollama *ollama.Embedder
// OpenAI is the OpenAI embedder, or nil.
OpenAI *openai.Embedder
// Qianfan is the Baidu Qianfan embedder, or nil.
Qianfan *qianfan.Embedder
// TencentCloud is the Tencent Cloud embedder, or nil.
TencentCloud *tencentcloud.Embedder
// DashScope is the Alibaba DashScope embedder, or nil.
DashScope *dashscope.Embedder
}
// tenantEmbedders is the package-level tenant registry: the key is the tenant
// ID and the value is that tenant's EmbedderSet.
// NOTE(review): this is a plain map with no synchronisation in this file;
// concurrent reads and writes of a Go map are a data race — confirm that
// registration and lookup never overlap, or guard it with a mutex.
var tenantEmbedders = make(map[uint64]*EmbedderSet)
// init runs when the package is loaded: it opens a tracing span named
// "InitAllVector" on a background context, calls InitAllVector under that
// span, and ends the span once initialisation returns.
func init() {
	spanCtx, span := jaeger.NewSpan(context.Background(), "InitAllVector")
	defer span.End()

	InitAllVector(spanCtx)
}
// ===================== 1. Called at service start-up: initialise all tenants =====================

// InitAllVector registers the embedding model(s) used at start-up.
//
// The database-driven path (load every vector model, call InitVector for each,
// log and skip failures) is currently disabled and kept below for reference.
// In its place a single DashScope "text-embedding-v3" model is registered for
// tenant 1. A failure is logged and otherwise ignored; nothing is returned.
func InitAllVector(ctx context.Context) {
	//list, err := dao.Model.GetNoTenantId(ctx, &dto.GetModelReq{
	//	ModelType: model.ModelTypeVector.Code(),
	//})
	//if err != nil {
	//	g.Log().Errorf(ctx, "获取所有租户ID失败: %v", err)
	//	return
	//}
	//
	//for _, l := range list {
	//	err = InitVector(ctx, l)
	//	if err != nil {
	//		g.Log().Errorf(ctx, "初始化租户[%v]的向量模型失败: %v", l.TenantId, err)
	//		continue
	//	}
	//}

	// NOTE(review): the API key below is a credential committed to source
	// control. It should be rotated and loaded from configuration or a secret
	// store instead of being hard-coded.
	var dashCfg entity.VectorModelConfigDashScope
	dashCfg.APIKey = "sk-4a8b82770bf74bc490eb3e4c5a8e2be9"
	dashCfg.Model = "text-embedding-v3"

	defaultModel := &entity.Model{}
	defaultModel.TenantId = 1
	defaultModel.ConfigType = model.ModelConfigTypeVectorDashScope.Code()
	defaultModel.ConfigContent = gconv.Map(&dashCfg)

	if err := InitVector(ctx, defaultModel); err != nil {
		g.Log().Errorf(ctx, "初始化向量模型失败: %v", err)
	}
}
// InitVector builds the embedder described by modelDO and registers it as the
// embedder set of modelDO.TenantId.
//
// modelDO.ConfigType selects the provider (Ark, Ollama, OpenAI, Qianfan,
// Tencent Cloud or DashScope) and modelDO.ConfigContent is decoded with
// gconv.Struct into that provider's config struct. On success a fresh
// EmbedderSet holding only this one embedder replaces whatever was stored for
// the tenant, so an embedder of another type previously registered for the same
// tenant is dropped.
//
// An error is returned when modelDO or its ConfigType is nil, when the config
// type is not supported, when the config content cannot be decoded, or when the
// provider constructor fails. Underlying errors are wrapped with %w so callers
// can inspect them with errors.Is / errors.As.
//
// NOTE(review): the registry map is written here without any locking — confirm
// this never runs concurrently with lookups.
func InitVector(ctx context.Context, modelDO *entity.Model) (err error) {
	// Guard the dereferences below; a nil model or type would otherwise panic.
	if modelDO == nil {
		return fmt.Errorf("向量模型配置为空")
	}
	if modelDO.ConfigType == nil {
		return fmt.Errorf("向量模型配置类型为空")
	}
	configType := *modelDO.ConfigType

	set := &EmbedderSet{}
	switch configType {
	case *model.ModelConfigTypeVectorArk.Code():
		// Decode the Ark embedding config.
		var cfg entity.VectorModelConfigArk
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析Ark向量配置失败: %w", err)
		}
		arkCfg := &ark.EmbeddingConfig{
			APIKey: cfg.APIKey,
			Model:  cfg.Model,
		}
		// APIType is optional; only set it when the config provides a value.
		if !g.IsEmpty(cfg.APIType) {
			// Use an addressable local instead of new(expr), which only
			// compiles on Go 1.26 and later.
			apiType := ark.APIType(cfg.APIType)
			arkCfg.APIType = &apiType
		}
		set.Ark, err = ark.NewEmbedder(ctx, arkCfg)

	case *model.ModelConfigTypeVectorOllama.Code():
		// Decode the Ollama embedding config.
		var cfg entity.VectorModelConfigOllama
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析Ollama向量配置失败: %w", err)
		}
		set.Ollama, err = ollama.NewEmbedder(ctx, &ollama.EmbeddingConfig{
			BaseURL: cfg.BaseURL,
			Model:   cfg.Model,
		})

	case *model.ModelConfigTypeVectorOpenAI.Code():
		// Decode the OpenAI embedding config.
		var cfg entity.VectorModelConfigOpenAI
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析OpenAI向量配置失败: %w", err)
		}
		set.OpenAI, err = openai.NewEmbedder(ctx, &openai.EmbeddingConfig{
			APIKey:     cfg.APIKey,
			Model:      cfg.Model,
			ByAzure:    cfg.ByAzure,
			BaseURL:    cfg.BaseURL,
			APIVersion: cfg.APIVersion,
		})

	case *model.ModelConfigTypeVectorQianfan.Code():
		// Decode the Qianfan embedding config.
		var cfg entity.VectorModelConfigQianfan
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析千帆向量配置失败: %w", err)
		}
		// NOTE(review): judging by its name this config object is a
		// process-wide singleton, so the keys written here would be shared by
		// every tenant using Qianfan — confirm against the qianfan package.
		qcfg := qianfan.GetQianfanSingletonConfig()
		qcfg.AccessKey = cfg.AccessKey
		qcfg.SecretKey = cfg.SecretKey
		set.Qianfan, err = qianfan.NewEmbedder(ctx, &qianfan.EmbeddingConfig{
			Model: cfg.Model,
		})

	case *model.ModelConfigTypeVectorTencentCloud.Code():
		// Decode the Tencent Cloud embedding config.
		var cfg entity.VectorModelConfigTencentCloud
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析腾讯云向量配置失败: %w", err)
		}
		set.TencentCloud, err = tencentcloud.NewEmbedder(ctx, &tencentcloud.EmbeddingConfig{
			SecretID:  cfg.SecretID,
			SecretKey: cfg.SecretKey,
			Region:    cfg.Region,
		})

	case *model.ModelConfigTypeVectorDashScope.Code():
		// Decode the Alibaba DashScope embedding config.
		var cfg entity.VectorModelConfigDashScope
		if err = gconv.Struct(modelDO.ConfigContent, &cfg); err != nil {
			return fmt.Errorf("解析阿里dashscope向量配置失败: %w", err)
		}
		set.DashScope, err = dashscope.NewEmbedder(ctx, &dashscope.EmbeddingConfig{
			APIKey: cfg.APIKey,
			Model:  cfg.Model,
		})

	default:
		return fmt.Errorf("不支持的向量模型配置类型: %v", configType)
	}

	// Shared handling for a failed provider constructor.
	if err != nil {
		return fmt.Errorf("初始化向量模型失败: %w", err)
	}

	// Store without locking; initialising the same tenant again overwrites the
	// previous entry.
	tenantEmbedders[modelDO.TenantId] = set
	// Log the dereferenced type value rather than the pointer itself.
	g.Log().Infof(ctx, "向量模型[%v]初始化成功", configType)
	return nil
}
// GetTenantEmbedder looks up the EmbedderSet registered for tenantId.
// It returns the set and a nil error when one is registered, or a nil set and
// an error when the tenant has no entry (or a nil entry).
func GetTenantEmbedder(tenantId uint64) (*EmbedderSet, error) {
	if registered := tenantEmbedders[tenantId]; registered != nil {
		return registered, nil
	}
	return nil, fmt.Errorf("租户[%v]的向量模型未初始化", tenantId)
}
// GetTenantEmbedderByType returns the embedder of the given provider type for
// the tenant of the user found in ctx.
//
// The tenant ID is read from utils.GetUserInfo(ctx) and the tenant's
// EmbedderSet is fetched with GetTenantEmbedder; configType then selects which
// field of the set to return.
//
// An error is returned when the user info cannot be read, when the tenant has
// no registered set, when configType is nil or unsupported, or when the tenant
// has no embedder of the requested type. In every error case the returned
// embedder is a true nil interface.
func GetTenantEmbedderByType(ctx context.Context, configType model.ModelConfigType) (embedding.Embedder, error) {
	// Guard the dereference below; a nil type would otherwise panic.
	if configType == nil {
		return nil, fmt.Errorf("向量模型配置类型为空")
	}
	kind := *configType

	userInfo, err := utils.GetUserInfo(ctx)
	if err != nil {
		return nil, err
	}
	set, err := GetTenantEmbedder(userInfo.TenantId)
	if err != nil {
		return nil, err
	}

	// Assign to the interface only when the concrete pointer is non-nil.
	// Returning a nil *T directly would produce a non-nil interface value that
	// passes a caller's nil check and then panics on first use.
	var found embedding.Embedder
	switch kind {
	case *model.ModelConfigTypeVectorArk.Code():
		if set.Ark != nil {
			found = set.Ark
		}
	case *model.ModelConfigTypeVectorOllama.Code():
		if set.Ollama != nil {
			found = set.Ollama
		}
	case *model.ModelConfigTypeVectorOpenAI.Code():
		if set.OpenAI != nil {
			found = set.OpenAI
		}
	case *model.ModelConfigTypeVectorQianfan.Code():
		if set.Qianfan != nil {
			found = set.Qianfan
		}
	case *model.ModelConfigTypeVectorTencentCloud.Code():
		if set.TencentCloud != nil {
			found = set.TencentCloud
		}
	case *model.ModelConfigTypeVectorDashScope.Code():
		if set.DashScope != nil {
			found = set.DashScope
		}
	default:
		return nil, fmt.Errorf("不支持的向量模型配置类型: %v", kind)
	}

	if found == nil {
		return nil, fmt.Errorf("租户[%v]未初始化类型为[%v]的向量模型", userInfo.TenantId, kind)
	}
	return found, nil
}
// RefreshTenantEmbedder discards the embedder set registered for
// modelDO.TenantId and rebuilds it from modelDO via InitVector.
//
// It returns an error when modelDO is nil, or whatever error InitVector
// returns.
//
// NOTE(review): the old entry is removed before the new one is built, so if
// InitVector fails the tenant is left with no embedder at all. Confirm this is
// intended (never serve a stale config) rather than keeping the previous set
// until the replacement succeeds.
func RefreshTenantEmbedder(ctx context.Context, modelDO *entity.Model) error {
	// Guard the field access below; a nil model would otherwise panic.
	if modelDO == nil {
		return fmt.Errorf("向量模型配置为空")
	}
	delete(tenantEmbedders, modelDO.TenantId)
	return InitVector(ctx, modelDO)
}