48 lines
1.1 KiB
Go
48 lines
1.1 KiB
Go
package eino
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/cloudwego/eino/components/embedding"
|
|
)
|
|
|
|
// BatchEmbedder 包装器,支持批量限制
|
|
type BatchEmbedder struct {
|
|
embedder embedding.Embedder
|
|
batchSize int
|
|
}
|
|
|
|
// NewBatchEmbedder 创建支持批量限制的 embedding 包装器
|
|
func NewBatchEmbedder(embedder embedding.Embedder, batchSize int) *BatchEmbedder {
|
|
if batchSize <= 0 {
|
|
batchSize = 10 // 默认每批 10 个
|
|
}
|
|
return &BatchEmbedder{
|
|
embedder: embedder,
|
|
batchSize: batchSize,
|
|
}
|
|
}
|
|
|
|
// EmbedStrings 分批调用 embedding
|
|
func (b *BatchEmbedder) EmbedStrings(ctx context.Context, texts []string, opts ...embedding.Option) ([][]float64, error) {
|
|
if len(texts) <= b.batchSize {
|
|
return b.embedder.EmbedStrings(ctx, texts, opts...)
|
|
}
|
|
|
|
var allEmbeddings [][]float64
|
|
for i := 0; i < len(texts); i += b.batchSize {
|
|
end := i + b.batchSize
|
|
if end > len(texts) {
|
|
end = len(texts)
|
|
}
|
|
|
|
batch := texts[i:end]
|
|
embeddings, err := b.embedder.EmbedStrings(ctx, batch, opts...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
allEmbeddings = append(allEmbeddings, embeddings...)
|
|
}
|
|
return allEmbeddings, nil
|
|
}
|