Files
rag/common/eino/c.go
2026-04-03 11:14:44 +08:00

95 lines
2.2 KiB
Go

package eino
import (
"context"
"encoding/json"
"fmt"
"github.com/cloudwego/eino/schema"
"github.com/elastic/go-elasticsearch/v8"
"github.com/elastic/go-elasticsearch/v8/typedapi/types"
"github.com/cloudwego/eino-ext/components/retriever/es8"
"github.com/cloudwego/eino-ext/components/retriever/es8/search_mode"
)
func TestRetriever() {
ctx := context.Background()
client, _ := elasticsearch.NewClient(elasticsearch.Config{
Addresses: []string{"http://localhost:9200"},
})
// 创建 retriever 组件
retriever, _ := es8.NewRetriever(ctx, &es8.RetrieverConfig{
Client: client,
Index: indexName,
TopK: 5,
SearchMode: search_mode.SearchModeApproximate(&search_mode.ApproximateConfig{
QueryFieldName: fieldContent,
VectorFieldName: fieldContentVector,
Hybrid: false,
// RRF 仅在特定许可证下可用
// 参见: https://www.elastic.co/subscriptions
RRF: false,
RRFRankConstant: nil,
RRFWindowSize: nil,
}),
ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) {
doc = &schema.Document{
ID: *hit.Id_,
Content: "",
MetaData: map[string]any{},
}
var src map[string]any
if err = json.Unmarshal(hit.Source_, &src); err != nil {
return nil, err
}
for field, val := range src {
switch field {
case fieldContent:
doc.Content = val.(string)
case fieldContentVector:
var v []float64
for _, item := range val.([]interface{}) {
v = append(v, item.(float64))
}
doc.WithDenseVector(v)
case fieldExtraLocation:
doc.MetaData[docExtraLocation] = val.(string)
}
}
if hit.Score_ != nil {
doc.WithScore(float64(*hit.Score_))
}
return doc, nil
},
Embedding: EmbedderDashscope,
})
// 不带过滤器的搜索
docs, _ := retriever.Retrieve(ctx, "tourist attraction")
// 带过滤器的搜索
docs, _ = retriever.Retrieve(ctx, "tourist attraction",
es8.WithFilters([]types.Query{{
Term: map[string]types.TermQuery{
fieldExtraLocation: {
CaseInsensitive: of(true),
Value: "China",
},
},
}}),
)
fmt.Printf("retrieved docs: %+v\n", docs)
}
// of returns a pointer to a fresh copy of v. It is a convenience for filling
// optional pointer-typed fields (for example *bool) in struct literals, where
// the address of a constant cannot be taken directly.
func of[T any](v T) *T {
	ptr := new(T)
	*ptr = v
	return ptr
}