package eino

import (
	"context"
	"encoding/json"
	"fmt"

	"github.com/cloudwego/eino/schema"
	"github.com/elastic/go-elasticsearch/v8"
	"github.com/elastic/go-elasticsearch/v8/typedapi/types"

	"github.com/cloudwego/eino-ext/components/retriever/es8"
	"github.com/cloudwego/eino-ext/components/retriever/es8/search_mode"
)

// TestRetriever demonstrates building an ES8-backed retriever and querying it
// twice: once without filters and once with a term filter on the location
// field. It connects to an Elasticsearch node at http://localhost:9200.
//
// Any setup or retrieval failure is reported on stdout and ends the example
// early instead of continuing with a nil client or retriever.
//
// NOTE(review): indexName, fieldContent, fieldContentVector,
// fieldExtraLocation, docExtraLocation and EmbedderDashscope are not defined
// in this snippet; they are assumed to be declared elsewhere in the package.
func TestRetriever() {
	ctx := context.Background()

	client, err := elasticsearch.NewClient(elasticsearch.Config{
		Addresses: []string{"http://localhost:9200"},
	})
	if err != nil {
		fmt.Printf("creating elasticsearch client: %v\n", err)
		return
	}

	// Create the retriever component.
	retriever, err := es8.NewRetriever(ctx, &es8.RetrieverConfig{
		Client: client,
		Index:  indexName,
		TopK:   5,
		SearchMode: search_mode.SearchModeApproximate(&search_mode.ApproximateConfig{
			QueryFieldName:  fieldContent,
			VectorFieldName: fieldContentVector,
			Hybrid:          false,
			// RRF is only available under certain licenses.
			// See: https://www.elastic.co/subscriptions
			RRF:             false,
			RRFRankConstant: nil,
			RRFWindowSize:   nil,
		}),
		// ResultParser converts one search hit into a schema.Document. It
		// returns an error (rather than panicking) when the hit has no ID or
		// when a known field in _source has an unexpected JSON type.
		ResultParser: func(ctx context.Context, hit types.Hit) (*schema.Document, error) {
			if hit.Id_ == nil {
				return nil, fmt.Errorf("parsing hit: missing _id")
			}

			doc := &schema.Document{
				ID:       *hit.Id_,
				Content:  "",
				MetaData: map[string]any{},
			}

			var src map[string]any
			if err := json.Unmarshal(hit.Source_, &src); err != nil {
				return nil, fmt.Errorf("decoding _source of hit %q: %w", doc.ID, err)
			}

			for field, val := range src {
				switch field {
				case fieldContent:
					content, ok := val.(string)
					if !ok {
						return nil, fmt.Errorf("hit %q field %q: expected string, got %T", doc.ID, field, val)
					}
					doc.Content = content

				case fieldContentVector:
					items, ok := val.([]any)
					if !ok {
						return nil, fmt.Errorf("hit %q field %q: expected array, got %T", doc.ID, field, val)
					}
					// encoding/json decodes every JSON number into float64
					// when the target is an untyped map.
					vec := make([]float64, 0, len(items))
					for i, item := range items {
						f, ok := item.(float64)
						if !ok {
							return nil, fmt.Errorf("hit %q field %q[%d]: expected number, got %T", doc.ID, field, i, item)
						}
						vec = append(vec, f)
					}
					doc.WithDenseVector(vec)

				case fieldExtraLocation:
					location, ok := val.(string)
					if !ok {
						return nil, fmt.Errorf("hit %q field %q: expected string, got %T", doc.ID, field, val)
					}
					doc.MetaData[docExtraLocation] = location
				}
			}

			if hit.Score_ != nil {
				doc.WithScore(float64(*hit.Score_))
			}

			return doc, nil
		},
		Embedding: EmbedderDashscope,
	})
	if err != nil {
		fmt.Printf("creating retriever: %v\n", err)
		return
	}

	// Search without filters.
	docs, err := retriever.Retrieve(ctx, "tourist attraction")
	if err != nil {
		fmt.Printf("retrieving without filters: %v\n", err)
		return
	}
	fmt.Printf("retrieved docs (no filter): %+v\n", docs)

	// Search with filters.
	docs, err = retriever.Retrieve(ctx, "tourist attraction",
		es8.WithFilters([]types.Query{{
			Term: map[string]types.TermQuery{
				fieldExtraLocation: {
					CaseInsensitive: of(true),
					Value:           "China",
				},
			},
		}}),
	)
	if err != nil {
		fmt.Printf("retrieving with filters: %v\n", err)
		return
	}

	fmt.Printf("retrieved docs: %+v\n", docs)
}

// of returns a pointer to a copy of v. It is a convenience for populating
// optional pointer-typed fields from literals.
func of[T any](v T) *T {
	return &v
}