95 lines
2.2 KiB
Go
95 lines
2.2 KiB
Go
package eino
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
|
|
"github.com/cloudwego/eino/schema"
|
|
"github.com/elastic/go-elasticsearch/v8"
|
|
"github.com/elastic/go-elasticsearch/v8/typedapi/types"
|
|
|
|
"github.com/cloudwego/eino-ext/components/retriever/es8"
|
|
"github.com/cloudwego/eino-ext/components/retriever/es8/search_mode"
|
|
)
|
|
|
|
func TestRetriever() {
|
|
ctx := context.Background()
|
|
|
|
client, _ := elasticsearch.NewClient(elasticsearch.Config{
|
|
Addresses: []string{"http://localhost:9200"},
|
|
})
|
|
|
|
// 创建 retriever 组件
|
|
retriever, _ := es8.NewRetriever(ctx, &es8.RetrieverConfig{
|
|
Client: client,
|
|
Index: indexName,
|
|
TopK: 5,
|
|
SearchMode: search_mode.SearchModeApproximate(&search_mode.ApproximateConfig{
|
|
QueryFieldName: fieldContent,
|
|
VectorFieldName: fieldContentVector,
|
|
Hybrid: false,
|
|
// RRF 仅在特定许可证下可用
|
|
// 参见: https://www.elastic.co/subscriptions
|
|
RRF: false,
|
|
RRFRankConstant: nil,
|
|
RRFWindowSize: nil,
|
|
}),
|
|
ResultParser: func(ctx context.Context, hit types.Hit) (doc *schema.Document, err error) {
|
|
doc = &schema.Document{
|
|
ID: *hit.Id_,
|
|
Content: "",
|
|
MetaData: map[string]any{},
|
|
}
|
|
|
|
var src map[string]any
|
|
if err = json.Unmarshal(hit.Source_, &src); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
for field, val := range src {
|
|
switch field {
|
|
case fieldContent:
|
|
doc.Content = val.(string)
|
|
case fieldContentVector:
|
|
var v []float64
|
|
for _, item := range val.([]interface{}) {
|
|
v = append(v, item.(float64))
|
|
}
|
|
doc.WithDenseVector(v)
|
|
case fieldExtraLocation:
|
|
doc.MetaData[docExtraLocation] = val.(string)
|
|
}
|
|
}
|
|
|
|
if hit.Score_ != nil {
|
|
doc.WithScore(float64(*hit.Score_))
|
|
}
|
|
|
|
return doc, nil
|
|
},
|
|
Embedding: EmbedderDashscope,
|
|
})
|
|
|
|
// 不带过滤器的搜索
|
|
docs, _ := retriever.Retrieve(ctx, "tourist attraction")
|
|
|
|
// 带过滤器的搜索
|
|
docs, _ = retriever.Retrieve(ctx, "tourist attraction",
|
|
es8.WithFilters([]types.Query{{
|
|
Term: map[string]types.TermQuery{
|
|
fieldExtraLocation: {
|
|
CaseInsensitive: of(true),
|
|
Value: "China",
|
|
},
|
|
},
|
|
}}),
|
|
)
|
|
|
|
fmt.Printf("retrieved docs: %+v\n", docs)
|
|
}
|
|
|
|
// of returns a pointer to a fresh copy of v. It is handy for populating
// optional pointer-typed fields from literals, e.g. of(true).
func of[T any](v T) *T {
	ptr := new(T)
	*ptr = v
	return ptr
}
|