package eino import ( "context" "fmt" "rag/consts/model" "rag/model/entity" "gitea.com/red-future/common/jaeger" "gitea.com/red-future/common/utils" "github.com/cloudwego/eino-ext/components/embedding/ark" "github.com/cloudwego/eino-ext/components/embedding/dashscope" "github.com/cloudwego/eino-ext/components/embedding/ollama" "github.com/cloudwego/eino-ext/components/embedding/openai" "github.com/cloudwego/eino-ext/components/embedding/qianfan" "github.com/cloudwego/eino-ext/components/embedding/tencentcloud" "github.com/cloudwego/eino/components/embedding" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" ) type EmbedderSet struct { Ark *ark.Embedder Ollama *ollama.Embedder OpenAI *openai.Embedder Qianfan *qianfan.Embedder TencentCloud *tencentcloud.Embedder DashScope *dashscope.Embedder } // 全局租户容器:key=tenantId,value=该租户的向量模型 var tenantEmbedders = make(map[uint64]*EmbedderSet) func init() { ctx := context.Background() ctx, span := jaeger.NewSpan(ctx, "InitAllVector") defer span.End() InitAllVector(ctx) return } // ===================== 1. 服务启动时调用:初始化所有租户 ===================== func InitAllVector(ctx context.Context) { //list, err := dao.Model.GetNoTenantId(ctx, &dto.GetModelReq{ // ModelType: model.ModelTypeVector.Code(), //}) //if err != nil { // g.Log().Errorf(ctx, "获取所有租户ID失败: %v", err) // return //} // //for _, l := range list { // err = InitVector(ctx, l) // if err != nil { // g.Log().Errorf(ctx, "初始化租户[%v]的向量模型失败: %v", l.TenantId, err) // continue // } //} modelDO := new(entity.Model) modelDO.TenantId = 1 modelDO.ConfigType = model.ModelConfigTypeVectorDashScope.Code() var cfg entity.VectorModelConfigDashScope cfg.APIKey = "sk-4a8b82770bf74bc490eb3e4c5a8e2be9" cfg.Model = "text-embedding-v3" modelDO.ConfigContent = gconv.Map(&cfg) err := InitVector(ctx, modelDO) if err != nil { g.Log().Errorf(ctx, "初始化向量模型失败: %v", err) return } } func InitVector(ctx context.Context, modelDO *entity.Model) (err error) { set := &EmbedderSet{} switch *modelDO.ConfigType { case *model.ModelConfigTypeVectorArk.Code(): // 解析 Ark 向量配置 var cfg entity.VectorModelConfigArk err = gconv.Struct(modelDO.ConfigContent, &cfg) if err != nil { return fmt.Errorf("解析Ark向量配置失败: %v", err) } arkCfg := &ark.EmbeddingConfig{ APIKey: cfg.APIKey, Model: cfg.Model, } if !g.IsEmpty(cfg.APIType) { arkCfg.APIType = new(ark.APIType(cfg.APIType)) } set.Ark, err = ark.NewEmbedder(ctx, arkCfg) case *model.ModelConfigTypeVectorOllama.Code(): // 解析 Ollama 向量配置 var cfg entity.VectorModelConfigOllama err = gconv.Struct(modelDO.ConfigContent, &cfg) if err != nil { return fmt.Errorf("解析Ollama向量配置失败: %v", err) } set.Ollama, err = ollama.NewEmbedder(ctx, &ollama.EmbeddingConfig{ BaseURL: cfg.BaseURL, Model: cfg.Model, }) case *model.ModelConfigTypeVectorOpenAI.Code(): // 解析 OpenAI 向量配置 var cfg entity.VectorModelConfigOpenAI err = gconv.Struct(modelDO.ConfigContent, &cfg) if err != nil { return fmt.Errorf("解析OpenAI向量配置失败: %v", err) } openaiCfg := &openai.EmbeddingConfig{ APIKey: cfg.APIKey, Model: cfg.Model, ByAzure: cfg.ByAzure, BaseURL: cfg.BaseURL, APIVersion: cfg.APIVersion, } set.OpenAI, err = openai.NewEmbedder(ctx, openaiCfg) case *model.ModelConfigTypeVectorQianfan.Code(): // 解析 千帆 向量配置 var cfg entity.VectorModelConfigQianfan err = gconv.Struct(modelDO.ConfigContent, &cfg) if err != nil { return fmt.Errorf("解析千帆向量配置失败: %v", err) } qcfg := qianfan.GetQianfanSingletonConfig() qcfg.AccessKey = cfg.AccessKey qcfg.SecretKey = cfg.SecretKey set.Qianfan, err = qianfan.NewEmbedder(ctx, &qianfan.EmbeddingConfig{ Model: cfg.Model, }) case *model.ModelConfigTypeVectorTencentCloud.Code(): // 解析 腾讯云 向量配置 var cfg entity.VectorModelConfigTencentCloud err = gconv.Struct(modelDO.ConfigContent, &cfg) if err != nil { return fmt.Errorf("解析腾讯云向量配置失败: %v", err) } set.TencentCloud, err = tencentcloud.NewEmbedder(ctx, &tencentcloud.EmbeddingConfig{ SecretID: cfg.SecretID, SecretKey: cfg.SecretKey, Region: cfg.Region, }) case *model.ModelConfigTypeVectorDashScope.Code(): // 解析 阿里 dashscope 向量配置 var cfg entity.VectorModelConfigDashScope err = gconv.Struct(modelDO.ConfigContent, &cfg) if err != nil { return fmt.Errorf("解析阿里dashscope向量配置失败: %v", err) } set.DashScope, err = dashscope.NewEmbedder(ctx, &dashscope.EmbeddingConfig{ APIKey: cfg.APIKey, Model: cfg.Model, }) default: return fmt.Errorf("不支持的向量模型配置类型: %v", *modelDO.ConfigType) } // 统一错误处理 if err != nil { return fmt.Errorf("初始化向量模型失败: %v", err) } // 直接存入 map(无锁,重复初始化会直接覆盖) tenantEmbedders[modelDO.TenantId] = set g.Log().Infof(ctx, "向量模型[%v]初始化成功", modelDO.ConfigType) return } func GetTenantEmbedder(tenantId uint64) (*EmbedderSet, error) { set := tenantEmbedders[tenantId] if set == nil { return nil, fmt.Errorf("租户[%v]的向量模型未初始化", tenantId) } return set, nil } func GetTenantEmbedderByType(ctx context.Context, configType model.ModelConfigType) (embedding.Embedder, error) { userInfo, err := utils.GetUserInfo(ctx) if err != nil { return nil, err } set, err := GetTenantEmbedder(userInfo.TenantId) if set == nil { return nil, err } switch *configType { case *model.ModelConfigTypeVectorArk.Code(): return set.Ark, nil case *model.ModelConfigTypeVectorOllama.Code(): return set.Ollama, nil case *model.ModelConfigTypeVectorOpenAI.Code(): return set.OpenAI, nil case *model.ModelConfigTypeVectorQianfan.Code(): return set.Qianfan, nil case *model.ModelConfigTypeVectorTencentCloud.Code(): return set.TencentCloud, nil case *model.ModelConfigTypeVectorDashScope.Code(): return set.DashScope, nil default: return nil, fmt.Errorf("不支持的向量模型配置类型: %v", *configType) } } func RefreshTenantEmbedder(ctx context.Context, modelDO *entity.Model) error { delete(tenantEmbedders, modelDO.TenantId) return InitVector(ctx, modelDO) }