166 lines
5.5 KiB
Go
166 lines
5.5 KiB
Go
package ragflow
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
)
|
||
|
||
// 数据集管理
|
||
// 参考: https://ragflow.com.cn/docs/dev/http_api_reference#数据集管理
|
||
|
||
// Dataset is a dataset record as returned by the dataset endpoints: the
// "data" payload of CreateDataset and the elements of ListDatasetsRes.Data.
//
// Each field is bound to the server's JSON key via its struct tag. A key the
// server omits (or sends as null) leaves the field at its zero value,
// assuming the response is decoded with encoding/json.
type Dataset struct {
	Id                     string                 `json:"id"`
	Name                   string                 `json:"name"`
	Avatar                 string                 `json:"avatar"`
	TenantId               string                 `json:"tenant_id"`
	Description            string                 `json:"description"`
	Language               string                 `json:"language"`
	EmbeddingModel         string                 `json:"embedding_model"`
	Permission             string                 `json:"permission"`
	DocumentCount          int                    `json:"document_count"`
	ChunkCount             int                    `json:"chunk_count"`
	ParseStatus            string                 `json:"parse_status"`
	CreatedBy              string                 `json:"created_by"`
	// CreateDate is the textual creation timestamp sent next to CreateTime,
	// mirroring the UpdateDate/UpdateTime pair.
	// NOTE(review): the format is server-defined; confirm before parsing.
	CreateDate             string                 `json:"create_date"`
	CreateTime             int64                  `json:"create_time"`
	UpdateDate             string                 `json:"update_date"`
	UpdateTime             int64                  `json:"update_time"`
	Status                 string                 `json:"status"`
	ChunkMethod            string                 `json:"chunk_method"`
	ParserConfig           map[string]interface{} `json:"parser_config"`
	VectorSimilarityWeight float64                `json:"vector_similarity_weight"`
	SimilarityThreshold    float64                `json:"similarity_threshold"`
	// PageRank is the "pagerank" value, the same key UpdateDatasetReq sets.
	PageRank               int                    `json:"pagerank"`
	TokenNum               int                    `json:"token_num"`
}
|
||
|
||
// CreateDatasetReq is the JSON body sent when creating a dataset.
//
// Name is always serialized. Every other field is tagged omitempty and is
// dropped from the body while it holds its zero value.
type CreateDatasetReq struct {
	Name           string                 `json:"name"`
	Avatar         string                 `json:"avatar,omitempty"`
	Description    string                 `json:"description,omitempty"`
	EmbeddingModel string                 `json:"embedding_model,omitempty"`
	Permission     string                 `json:"permission,omitempty"`
	ChunkMethod    string                 `json:"chunk_method,omitempty"`
	ParserConfig   map[string]interface{} `json:"parser_config,omitempty"`
}
|
||
|
||
// UpdateDatasetReq is the JSON body sent when updating a dataset.
//
// Every field is tagged omitempty, so only non-zero fields reach the server.
// A consequence is that this type cannot reset a field to its zero value:
// PageRank=0 or an empty Description is simply left out of the body.
type UpdateDatasetReq struct {
	Name           string                 `json:"name,omitempty"`
	Avatar         string                 `json:"avatar,omitempty"`
	Description    string                 `json:"description,omitempty"`
	EmbeddingModel string                 `json:"embedding_model,omitempty"`
	Permission     string                 `json:"permission,omitempty"`
	ChunkMethod    string                 `json:"chunk_method,omitempty"`
	PageRank       int                    `json:"pagerank,omitempty"`
	ParserConfig   map[string]interface{} `json:"parser_config,omitempty"`
}
|
||
|
||
// ListDatasetsReq holds the filters and paging options for listing datasets.
//
// The json tags mirror the query-parameter names used when listing. Note
// that Desc is a plain bool, so its zero value (false) cannot be told apart
// from "not set".
type ListDatasetsReq struct {
	Page     int    `json:"page,omitempty"`
	PageSize int    `json:"page_size,omitempty"`
	OrderBy  string `json:"orderby,omitempty"`
	Desc     bool   `json:"desc,omitempty"`
	Name     string `json:"name,omitempty"`
	Id       string `json:"id,omitempty"`
}
|
||
|
||
// ListDatasetsRes 列出数据集响应
|
||
type ListDatasetsRes struct {
|
||
Code int `json:"code"`
|
||
Data []*Dataset `json:"data"`
|
||
Total int `json:"total"`
|
||
}
|
||
|
||
// DeleteDatasetsReq is the JSON body for deleting datasets by ID.
//
// Ids has no omitempty tag: a nil slice is serialized as "ids": null and an
// empty non-nil slice as "ids": [].
type DeleteDatasetsReq struct {
	Ids []string `json:"ids"`
}
|
||
|
||
// CreateDataset 创建数据集
|
||
func (c *Client) CreateDataset(ctx context.Context, req *CreateDatasetReq) (*Dataset, error) {
|
||
var res struct {
|
||
Code int `json:"code"`
|
||
Data *Dataset `json:"data"`
|
||
Msg string `json:"message"`
|
||
}
|
||
if err := c.request(ctx, "POST", "/api/v1/datasets", req, &res); err != nil {
|
||
return nil, err
|
||
}
|
||
if res.Code != 0 {
|
||
return nil, fmt.Errorf("create dataset failed: %s", res.Msg)
|
||
}
|
||
return res.Data, nil
|
||
}
|
||
|
||
// ListDatasets 列出数据集
|
||
func (c *Client) ListDatasets(ctx context.Context, req *ListDatasetsReq) (*ListDatasetsRes, error) {
|
||
// 构建查询参数
|
||
path := "/api/v1/datasets?"
|
||
params := map[string]interface{}{}
|
||
if req.Page > 0 {
|
||
params["page"] = req.Page
|
||
}
|
||
if req.PageSize > 0 {
|
||
params["page_size"] = req.PageSize
|
||
}
|
||
if req.OrderBy != "" {
|
||
params["orderby"] = req.OrderBy
|
||
}
|
||
// desc 默认为 true,如果显式设置为 false 才传递,或者根据 API 行为调整
|
||
// 这里简单处理,如果设置了就传
|
||
if req.Desc {
|
||
params["desc"] = "true"
|
||
} else {
|
||
params["desc"] = "false"
|
||
}
|
||
if req.Name != "" {
|
||
params["name"] = req.Name
|
||
}
|
||
if req.Id != "" {
|
||
params["id"] = req.Id
|
||
}
|
||
|
||
// 拼接 query string
|
||
query := buildQueryString(params)
|
||
if query != "" {
|
||
path += "?" + query
|
||
}
|
||
|
||
var res ListDatasetsRes
|
||
if err := c.request(ctx, "GET", path, nil, &res); err != nil {
|
||
return nil, err
|
||
}
|
||
if res.Code != 0 {
|
||
return nil, fmt.Errorf("list datasets failed: code=%d", res.Code)
|
||
}
|
||
return &res, nil
|
||
}
|
||
|
||
// DeleteDataset 删除数据集
|
||
func (c *Client) DeleteDataset(ctx context.Context, ids []string) error {
|
||
req := DeleteDatasetsReq{Ids: ids}
|
||
var res CommonResponse
|
||
if err := c.request(ctx, "DELETE", "/api/v1/datasets", req, &res); err != nil {
|
||
return err
|
||
}
|
||
if !res.IsSuccess() {
|
||
return fmt.Errorf("delete dataset failed: %s", res.Message)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// UpdateDataset 更新数据集
|
||
func (c *Client) UpdateDataset(ctx context.Context, id string, req *UpdateDatasetReq) error {
|
||
var res CommonResponse
|
||
path := fmt.Sprintf("/api/v1/datasets/%s", id)
|
||
if err := c.request(ctx, "PUT", path, req, &res); err != nil {
|
||
return err
|
||
}
|
||
if !res.IsSuccess() {
|
||
return fmt.Errorf("update dataset failed: %s", res.Message)
|
||
}
|
||
return nil
|
||
}
|