feat: 支持多模型提供商 embedding

This commit is contained in:
2026-04-01 13:38:33 +08:00
parent bcbe6eba78
commit 2e4a0a89f1
9 changed files with 631 additions and 447 deletions

8
rag/eino/consts.go Normal file
View File

@@ -0,0 +1,8 @@
package eino
// Embedding provider identifiers, compared against the value of the
// "eino.embedding.provider" configuration key.
const (
	// providerArk selects the ark embedder.
	providerArk = "ark"
	// providerDashscope selects the dashscope embedder.
	providerDashscope = "dashscope"
	// providerOpenai selects the openai embedder.
	providerOpenai = "openai"
	// providerQianfan is the identifier reserved for qianfan.
	providerQianfan = "qianfan"
)

View File

@@ -5,59 +5,60 @@ import (
"fmt"

"github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive"
"github.com/cloudwego/eino-ext/components/document/transformer/splitter/semantic"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/schema"
"github.com/gogf/gf/v2/frame/g"
)
// Initialized only once globally.
// The split functions below build this value when g.IsEmpty(splitter) reports
// it as unset.
// NOTE(review): the same functions also contain `splitter, err := ...`
// declarations, which would shadow this variable — confirm whether this
// package-level declaration is still needed.
var (
splitter document.Transformer
)
// SemanticSplitDocument 语义分割文档
func SemanticSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) {
if g.IsEmpty(splitter) {
// 默认分隔符(支持中英文)
separators := []string{"\n\n", "\n", "。", "", "", "", ".", "!", "?", ";"}
// 读取配置,使用合理的默认值
bufferSize := g.Cfg().MustGet(ctx, "eino.splitter.bufferSize").Int()
percentile := g.Cfg().MustGet(ctx, "eino.splitter.percentile").Float64()
batchSize := g.Cfg().MustGet(ctx, "eino.splitter.batchSize").Int()
if batchSize <= 0 {
batchSize = 10 // doubao-embedding-vision 限制每批最多 10 个
}
// 默认分隔符(支持中英文)
separators := []string{"\n\n", "\n", "。", "", "", "", ".", "!", "?", ";"}
// 读取配置,使用合理的默认值
bufferSize := g.Cfg().MustGet(ctx, "eino.splitter.bufferSize").Int()
minChunkSize := g.Cfg().MustGet(ctx, "eino.splitter.minChunkSize").Int()
percentile := g.Cfg().MustGet(ctx, "eino.splitter.percentile").Float64()
batchSize := g.Cfg().MustGet(ctx, "eino.splitter.batchSize").Int()
if batchSize <= 0 {
batchSize = 10 // doubao-embedding-vision 限制每批最多 10 个
}
// 使用批量包装器
batchEmbedder := NewBatchEmbedder(Embedder, batchSize)
// 使用批量包装器
var batchEmbedder *BatchEmbedder
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
batchEmbedder = NewBatchEmbedder(EmbedderArk, batchSize)
case providerOpenai:
batchEmbedder = NewBatchEmbedder(EmbedderOpenAI, batchSize)
case providerDashscope:
batchEmbedder = NewBatchEmbedder(EmbedderDashscope, batchSize)
}
splitter, err = semantic.NewSplitter(ctx, &semantic.Config{
Embedding: batchEmbedder,
BufferSize: bufferSize,
Percentile: percentile,
Separators: separators,
})
if err != nil {
return
}
splitter, err := semantic.NewSplitter(ctx, &semantic.Config{
Embedding: batchEmbedder,
BufferSize: bufferSize,
MinChunkSize: minChunkSize,
Percentile: percentile,
Separators: separators,
})
if err != nil {
return
}
return splitter.Transform(ctx, docs)
}
// RecursiveSplitDocument 递归分割文档
func RecursiveSplitDocument(ctx context.Context, docs []*schema.Document) (res []*schema.Document, err error) {
if g.IsEmpty(splitter) {
// 默认分隔符(支持中英文)
separators := []string{"\n\n", "\n", "。", "", "", "", ".", "!", "?", ";"}
splitter, err = recursive.NewSplitter(ctx, &recursive.Config{
ChunkSize: 1500,
OverlapSize: 300,
KeepType: recursive.KeepTypeNone,
Separators: separators,
})
if err != nil {
return
}
// 默认分隔符(支持中英文)
separators := []string{"\n\n", "\n", "。", "", "", "", ".", "!", "?", ";"}
splitter, err := recursive.NewSplitter(ctx, &recursive.Config{
ChunkSize: 512,
OverlapSize: 100,
KeepType: recursive.KeepTypeNone,
Separators: separators,
})
if err != nil {
return
}
return splitter.Transform(ctx, docs)
}

View File

@@ -2,45 +2,68 @@ package eino
import (
"context"
"fmt"
"github.com/cloudwego/eino-ext/components/embedding/ark"
"github.com/cloudwego/eino-ext/components/embedding/dashscope"
"github.com/cloudwego/eino-ext/components/embedding/openai"
"github.com/gogf/gf/v2/frame/g"
"github.com/golang/glog"
)
// Initialized only once globally; the embedder variables are assigned in init.
var (
// NOTE(review): this dashscope embedder appears alongside the
// provider-specific variables below — confirm whether it is still needed.
Embedder *dashscope.Embedder // exported for use by other modules
// EmbedderArk is assigned in init when the provider is "ark".
EmbedderArk *ark.Embedder
// EmbedderDashscope is assigned in init when the provider is "dashscope".
EmbedderDashscope *dashscope.Embedder
// EmbedderOpenAI is assigned in init when the provider is "openai".
EmbedderOpenAI *openai.Embedder
)
// init程序启动时自动执行一次
func init() {
ctx := context.Background()
if !g.Cfg().MustGet(ctx, "eino.embedding").IsEmpty() {
var err error
cfg := &dashscope.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
cfg := &ark.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
apiTypeVal := ark.APIType(apiType)
cfg.APIType = &apiTypeVal
}
EmbedderArk, err = ark.NewEmbedder(ctx, cfg)
case providerOpenai:
chatModelConfig := &openai.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
EmbedderOpenAI, err = openai.NewEmbedder(ctx, chatModelConfig)
case providerDashscope:
cfg := &dashscope.EmbeddingConfig{
APIKey: g.Cfg().MustGet(ctx, "eino.embedding.apiKey").String(),
Model: g.Cfg().MustGet(ctx, "eino.embedding.model").String(),
}
EmbedderDashscope, err = dashscope.NewEmbedder(ctx, cfg)
}
// 检查是否配置了 APIType支持 "text_api" 和 "multi_modal_api"
//if apiType := g.Cfg().MustGet(ctx, "eino.embedding.apiType").String(); apiType != "" {
// apiTypeVal := dashscope.APIType(apiType)
// cfg.APIType = &apiTypeVal
//}
Embedder, err = dashscope.NewEmbedder(ctx, cfg)
if err != nil {
glog.Fatalf("NewEmbedder of ark error: %v", err)
glog.Fatalf("NewEmbedder of %v error: %v", provider, err)
}
//embedding, err := embedder.EmbedStrings(ctx, []string{"hello world", "bye bye"})
//if err != nil {
// log.Printf("embedding error: %v\n", err)
// return
//}
//
//log.Printf("embedding: %v\n", embedding)
}
return
}
func EmbedStrings(ctx context.Context, texts []string) (embeddings [][]float64, err error) {
return Embedder.EmbedStrings(ctx, texts)
provider := g.Cfg().MustGet(ctx, "eino.embedding.provider").String()
switch provider {
case providerArk:
return EmbedderArk.EmbedStrings(ctx, texts)
case providerOpenai:
return EmbedderOpenAI.EmbedStrings(ctx, texts)
case providerDashscope:
return EmbedderDashscope.EmbedStrings(ctx, texts)
}
return nil, fmt.Errorf("unsupported provider: %v", provider)
}