diff --git a/ragflow/chat.go b/ragflow/chat.go new file mode 100644 index 0000000..addc182 --- /dev/null +++ b/ragflow/chat.go @@ -0,0 +1,172 @@ +package ragflow + +import ( + "context" + "fmt" +) + +// Chat 结构体 +type Chat struct { + Id string `json:"id"` + Name string `json:"name"` + Avatar string `json:"avatar"` + DatasetIds []string `json:"dataset_ids"` + Llm Llm `json:"llm"` + Prompt Prompt `json:"prompt"` + Description string `json:"description"` + DoRefer string `json:"do_refer"` + Language string `json:"language"` + PromptType string `json:"prompt_type"` + Status string `json:"status"` + TenantId string `json:"tenant_id"` + TopK int `json:"top_k"` + CreateDate string `json:"create_date"` + CreateTime int64 `json:"create_time"` + UpdateDate string `json:"update_date"` + UpdateTime int64 `json:"update_time"` +} + +type Llm struct { + ModelName string `json:"model_name,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` +} + +type Prompt struct { + SimilarityThreshold float64 `json:"similarity_threshold,omitempty"` + KeywordsSimilarityWeight float64 `json:"keywords_similarity_weight,omitempty"` + Opener string `json:"opener,omitempty"` + Prompt string `json:"prompt,omitempty"` + RerankModel string `json:"rerank_model,omitempty"` + TopN int `json:"top_n,omitempty"` + Variables []Variable `json:"variables,omitempty"` + EmptyResponse string `json:"empty_response,omitempty"` +} + +type Variable struct { + Key string `json:"key"` + Optional bool `json:"optional"` +} + +// CreateChatReq 创建聊天助手请求 +type CreateChatReq struct { + Name string `json:"name"` + Avatar string `json:"avatar,omitempty"` + DatasetIds []string `json:"dataset_ids,omitempty"` + Llm *Llm `json:"llm,omitempty"` + Prompt *Prompt `json:"prompt,omitempty"` +} + +// UpdateChatReq 更新聊天助手请求 +type UpdateChatReq struct { + Name string `json:"name,omitempty"` + Avatar string `json:"avatar,omitempty"` + DatasetIds []string `json:"dataset_ids,omitempty"` + Llm *Llm `json:"llm,omitempty"` + Prompt *Prompt `json:"prompt,omitempty"` +} + +// ListChatsReq 列出聊天助手请求 +type ListChatsReq struct { + Page int `json:"page,omitempty"` + PageSize int `json:"page_size,omitempty"` + OrderBy string `json:"orderby,omitempty"` + Desc bool `json:"desc,omitempty"` + Name string `json:"name,omitempty"` + Id string `json:"id,omitempty"` +} + +// ListChatsRes 列出聊天助手响应 +type ListChatsRes struct { + Code int `json:"code"` + Data []*Chat `json:"data"` + Total int `json:"total"` // API 文档中未明确 total 字段,但通常列表接口会有 +} + +// DeleteChatsReq 删除聊天助手请求 +type DeleteChatsReq struct { + Ids []string `json:"ids"` +} + +// CreateChat 创建聊天助手 +func (c *Client) CreateChat(ctx context.Context, req *CreateChatReq) (*Chat, error) { + var res struct { + Code int `json:"code"` + Data *Chat `json:"data"` + Msg string `json:"message"` + } + if err := c.request(ctx, "POST", "/api/v1/chats", req, &res); err != nil { + return nil, err + } + if res.Code != 0 { + return nil, fmt.Errorf("create chat failed: %s", res.Msg) + } + return res.Data, nil +} + +// ListChats 列出聊天助手 +func (c *Client) ListChats(ctx context.Context, req *ListChatsReq) (*ListChatsRes, error) { + path := "/api/v1/chats?" + params := map[string]interface{}{} + if req.Page > 0 { + params["page"] = req.Page + } + if req.PageSize > 0 { + params["page_size"] = req.PageSize + } + if req.OrderBy != "" { + params["orderby"] = req.OrderBy + } + if req.Desc { + params["desc"] = "true" + } else { + params["desc"] = "false" + } + if req.Name != "" { + params["name"] = req.Name + } + if req.Id != "" { + params["id"] = req.Id + } + + for k, v := range params { + path += fmt.Sprintf("%s=%v&", k, v) + } + + var res ListChatsRes + if err := c.request(ctx, "GET", path, nil, &res); err != nil { + return nil, err + } + if res.Code != 0 { + return nil, fmt.Errorf("list chats failed: code=%d", res.Code) + } + return &res, nil +} + +// DeleteChats 删除聊天助手 +func (c *Client) DeleteChats(ctx context.Context, ids []string) error { + req := DeleteChatsReq{Ids: ids} + var res CommonResponse + if err := c.request(ctx, "DELETE", "/api/v1/chats", req, &res); err != nil { + return err + } + if !res.IsSuccess() { + return fmt.Errorf("delete chats failed: %s", res.Message) + } + return nil +} + +// UpdateChat 更新聊天助手 +func (c *Client) UpdateChat(ctx context.Context, id string, req *UpdateChatReq) error { + var res CommonResponse + path := fmt.Sprintf("/api/v1/chats/%s", id) + if err := c.request(ctx, "PUT", path, req, &res); err != nil { + return err + } + if !res.IsSuccess() { + return fmt.Errorf("update chat failed: %s", res.Message) + } + return nil +} diff --git a/ragflow/chunk.go b/ragflow/chunk.go new file mode 100644 index 0000000..e2c2182 --- /dev/null +++ b/ragflow/chunk.go @@ -0,0 +1,174 @@ +package ragflow + +import ( + "context" + "fmt" +) + +// Chunk 结构体 +type Chunk struct { + Id string `json:"id"` + Content string `json:"content"` + DocumentId string `json:"document_id"` + DatasetId string `json:"dataset_id"` + CreateTime string `json:"create_time"` + CreateTimestamp float64 `json:"create_timestamp"` + ImportantKeywords []string `json:"important_keywords"` + Questions []string `json:"questions"` + Available bool `json:"available"` + ImageId string `json:"image_id"` + Positions []string `json:"positions"` +} + +// AddChunkReq 添加知识块请求 +type AddChunkReq struct { + Content string `json:"content"` + ImportantKeywords []string `json:"important_keywords,omitempty"` + Questions []string `json:"questions,omitempty"` +} + +// ListChunksReq 列出知识块请求 +type ListChunksReq struct { + Keywords string `json:"keywords,omitempty"` + Page int `json:"page,omitempty"` + PageSize int `json:"page_size,omitempty"` + Id string `json:"id,omitempty"` +} + +// ListChunksRes 列出知识块响应 +type ListChunksRes struct { + Code int `json:"code"` + Data struct { + Chunks []*Chunk `json:"chunks"` + Doc interface{} `json:"doc"` // 文档信息,暂时用 interface{} + Total int `json:"total"` + } `json:"data"` +} + +// DeleteChunksReq 删除知识块请求 +type DeleteChunksReq struct { + ChunkIds []string `json:"chunk_ids,omitempty"` // 如果为空,删除所有 +} + +// UpdateChunkReq 更新知识块请求 +type UpdateChunkReq struct { + Content string `json:"content,omitempty"` + ImportantKeywords []string `json:"important_keywords,omitempty"` + Available *bool `json:"available,omitempty"` +} + +// RetrieveChunksReq 检索知识块请求 +type RetrieveChunksReq struct { + Question string `json:"question"` + DatasetIds []string `json:"dataset_ids,omitempty"` + DocumentIds []string `json:"document_ids,omitempty"` + Page int `json:"page,omitempty"` + PageSize int `json:"page_size,omitempty"` + SimilarityThreshold float64 `json:"similarity_threshold,omitempty"` + VectorSimilarityWeight float64 `json:"vector_similarity_weight,omitempty"` + TopK int `json:"top_k,omitempty"` + RerankId string `json:"rerank_id,omitempty"` + Keyword bool `json:"keyword,omitempty"` + Highlight bool `json:"highlight,omitempty"` + CrossLanguages []string `json:"cross_languages,omitempty"` + MetadataCondition map[string]interface{} `json:"metadata_condition,omitempty"` +} + +// RetrieveChunksRes 检索知识块响应 (结构比较复杂,暂时简化,根据实际返回调整) +// 官方文档未给出详细响应结构,假设返回 chunks 列表 +type RetrieveChunksRes struct { + Code int `json:"code"` + Data struct { + Chunks []interface{} `json:"chunks"` // 检索结果可能包含额外信息 + Total int `json:"total"` + } `json:"data"` +} + +// AddChunk 添加知识块 +func (c *Client) AddChunk(ctx context.Context, datasetId, documentId string, req *AddChunkReq) (*Chunk, error) { + path := fmt.Sprintf("/api/v1/datasets/%s/documents/%s/chunks", datasetId, documentId) + var res struct { + Code int `json:"code"` + Data struct { + Chunk *Chunk `json:"chunk"` + } `json:"data"` + Msg string `json:"message"` + } + if err := c.request(ctx, "POST", path, req, &res); err != nil { + return nil, err + } + if res.Code != 0 { + return nil, fmt.Errorf("add chunk failed: %s", res.Msg) + } + return res.Data.Chunk, nil +} + +// ListChunks 列出知识块 +func (c *Client) ListChunks(ctx context.Context, datasetId, documentId string, req *ListChunksReq) (*ListChunksRes, error) { + path := fmt.Sprintf("/api/v1/datasets/%s/documents/%s/chunks?", datasetId, documentId) + params := map[string]interface{}{} + if req.Keywords != "" { + params["keywords"] = req.Keywords + } + if req.Page > 0 { + params["page"] = req.Page + } + if req.PageSize > 0 { + params["page_size"] = req.PageSize + } + if req.Id != "" { + params["id"] = req.Id + } + + for k, v := range params { + path += fmt.Sprintf("%s=%v&", k, v) + } + + var res ListChunksRes + if err := c.request(ctx, "GET", path, nil, &res); err != nil { + return nil, err + } + if res.Code != 0 { + return nil, fmt.Errorf("list chunks failed: code=%d", res.Code) + } + return &res, nil +} + +// DeleteChunks 删除知识块 +func (c *Client) DeleteChunks(ctx context.Context, datasetId, documentId string, chunkIds []string) error { + req := DeleteChunksReq{ChunkIds: chunkIds} + var res CommonResponse + path := fmt.Sprintf("/api/v1/datasets/%s/documents/%s/chunks", datasetId, documentId) + if err := c.request(ctx, "DELETE", path, req, &res); err != nil { + return err + } + if !res.IsSuccess() { + return fmt.Errorf("delete chunks failed: %s", res.Message) + } + return nil +} + +// UpdateChunk 更新知识块 +func (c *Client) UpdateChunk(ctx context.Context, datasetId, documentId, chunkId string, req *UpdateChunkReq) error { + var res CommonResponse + path := fmt.Sprintf("/api/v1/datasets/%s/documents/%s/chunks/%s", datasetId, documentId, chunkId) + if err := c.request(ctx, "PUT", path, req, &res); err != nil { + return err + } + if !res.IsSuccess() { + return fmt.Errorf("update chunk failed: %s", res.Message) + } + return nil +} + +// RetrieveChunks 检索知识块 +func (c *Client) RetrieveChunks(ctx context.Context, req *RetrieveChunksReq) (*RetrieveChunksRes, error) { + var res RetrieveChunksRes + if err := c.request(ctx, "POST", "/api/v1/retrieval", req, &res); err != nil { + return nil, err + } + if res.Code != 0 { + return nil, fmt.Errorf("retrieve chunks failed: code=%d", res.Code) + } + return &res, nil +} diff --git a/ragflow/dataset.go b/ragflow/dataset.go new file mode 100644 index 0000000..ea0de16 --- /dev/null +++ b/ragflow/dataset.go @@ -0,0 +1,161 @@ +package ragflow + +import ( + "context" + "fmt" +) + +// Dataset 结构体 +type Dataset struct { + Id string `json:"id"` + Name string `json:"name"` + Avatar string `json:"avatar"` + TenantId string `json:"tenant_id"` + Description string `json:"description"` + Language string `json:"language"` + EmbeddingModel string `json:"embedding_model"` + Permission string `json:"permission"` + DocumentCount int `json:"document_count"` + ChunkCount int `json:"chunk_count"` + ParseStatus string `json:"parse_status"` + CreatedBy string `json:"created_by"` + CreateTime int64 `json:"create_time"` + UpdateDate string `json:"update_date"` + UpdateTime int64 `json:"update_time"` + Status string `json:"status"` + ChunkMethod string `json:"chunk_method"` + ParserConfig map[string]interface{} `json:"parser_config"` + VectorSimilarityWeight float64 `json:"vector_similarity_weight"` + SimilarityThreshold float64 `json:"similarity_threshold"` + TokenNum int `json:"token_num"` +} + +// CreateDatasetReq 创建数据集请求 +type CreateDatasetReq struct { + Name string `json:"name"` + Avatar string `json:"avatar,omitempty"` + Description string `json:"description,omitempty"` + EmbeddingModel string `json:"embedding_model,omitempty"` + Permission string `json:"permission,omitempty"` + ChunkMethod string `json:"chunk_method,omitempty"` + ParserConfig map[string]interface{} `json:"parser_config,omitempty"` +} + +// UpdateDatasetReq 更新数据集请求 +type UpdateDatasetReq struct { + Name string `json:"name,omitempty"` + Avatar string `json:"avatar,omitempty"` + Description string `json:"description,omitempty"` + EmbeddingModel string `json:"embedding_model,omitempty"` + Permission string `json:"permission,omitempty"` + ChunkMethod string `json:"chunk_method,omitempty"` + PageRank int `json:"pagerank,omitempty"` + ParserConfig map[string]interface{} `json:"parser_config,omitempty"` +} + +// ListDatasetsReq 列出数据集请求 +type ListDatasetsReq struct { + Page int `json:"page,omitempty"` + PageSize int `json:"page_size,omitempty"` + OrderBy string `json:"orderby,omitempty"` + Desc bool `json:"desc,omitempty"` + Name string `json:"name,omitempty"` + Id string `json:"id,omitempty"` +} + +// ListDatasetsRes 列出数据集响应 +type ListDatasetsRes struct { + Code int `json:"code"` + Data []*Dataset `json:"data"` + Total int `json:"total"` +} + +// DeleteDatasetsReq 删除数据集请求 +type DeleteDatasetsReq struct { + Ids []string `json:"ids"` +} + +// CreateDataset 创建数据集 +func (c *Client) CreateDataset(ctx context.Context, req *CreateDatasetReq) (*Dataset, error) { + var res struct { + Code int `json:"code"` + Data *Dataset `json:"data"` + Msg string `json:"message"` + } + if err := c.request(ctx, "POST", "/api/v1/datasets", req, &res); err != nil { + return nil, err + } + if res.Code != 0 { + return nil, fmt.Errorf("create dataset failed: %s", res.Msg) + } + return res.Data, nil +} + +// ListDatasets 列出数据集 +func (c *Client) ListDatasets(ctx context.Context, req *ListDatasetsReq) (*ListDatasetsRes, error) { + // 构建查询参数 + path := "/api/v1/datasets?" + params := map[string]interface{}{} + if req.Page > 0 { + params["page"] = req.Page + } + if req.PageSize > 0 { + params["page_size"] = req.PageSize + } + if req.OrderBy != "" { + params["orderby"] = req.OrderBy + } + // desc 默认为 true,如果显式设置为 false 才传递,或者根据 API 行为调整 + // 这里简单处理,如果设置了就传 + if req.Desc { + params["desc"] = "true" + } else { + params["desc"] = "false" + } + if req.Name != "" { + params["name"] = req.Name + } + if req.Id != "" { + params["id"] = req.Id + } + + // 拼接 query string + for k, v := range params { + path += fmt.Sprintf("%s=%v&", k, v) + } + + var res ListDatasetsRes + if err := c.request(ctx, "GET", path, nil, &res); err != nil { + return nil, err + } + if res.Code != 0 { + return nil, fmt.Errorf("list datasets failed: code=%d", res.Code) + } + return &res, nil +} + +// DeleteDataset 删除数据集 +func (c *Client) DeleteDataset(ctx context.Context, ids []string) error { + req := DeleteDatasetsReq{Ids: ids} + var res CommonResponse + if err := c.request(ctx, "DELETE", "/api/v1/datasets", req, &res); err != nil { + return err + } + if !res.IsSuccess() { + return fmt.Errorf("delete dataset failed: %s", res.Message) + } + return nil +} + +// UpdateDataset 更新数据集 +func (c *Client) UpdateDataset(ctx context.Context, id string, req *UpdateDatasetReq) error { + var res CommonResponse + path := fmt.Sprintf("/api/v1/datasets/%s", id) + if err := c.request(ctx, "PUT", path, req, &res); err != nil { + return err + } + if !res.IsSuccess() { + return fmt.Errorf("update dataset failed: %s", res.Message) + } + return nil +} diff --git a/ragflow/document.go b/ragflow/document.go new file mode 100644 index 0000000..0d2e5e8 --- /dev/null +++ b/ragflow/document.go @@ -0,0 +1,132 @@ +package ragflow + +import ( + "context" + "fmt" +) + +// Document 结构体 +type Document struct { + Id string `json:"id"` + DatasetId string `json:"dataset_id"` + Name string `json:"name"` + Size int64 `json:"size"` + Location string `json:"location"` + CreatedBy string `json:"created_by"` + CreateTime int64 `json:"create_time"` + Thumbnail string `json:"thumbnail"` + Type string `json:"type"` + RunStatus string `json:"run_status"` // 对应 API 返回的 "run" 字段,可能需要确认 + Status string `json:"status"` + ChunkMethod string `json:"chunk_method"` + ParserConfig map[string]interface{} `json:"parser_config"` + TokenNum int `json:"token_num"` + ChunkCount int `json:"chunk_count"` + ProcessBegin int64 `json:"process_begin"` + ProcessDu int64 `json:"process_du"` + Progress float64 `json:"progress"` + ProgressMsg string `json:"progress_msg"` +} + +// UploadDocumentReq 上传文档请求 +// 注意:上传文件通常需要 multipart/form-data,这里仅定义结构,实际逻辑在方法中处理 +type UploadDocumentReq struct { + FilePaths []string // 本地文件路径列表 +} + +// ListDocumentsReq 列出文档请求 +type ListDocumentsReq struct { + Page int `json:"page,omitempty"` + PageSize int `json:"page_size,omitempty"` + OrderBy string `json:"orderby,omitempty"` + Desc bool `json:"desc,omitempty"` + Keywords string `json:"keywords,omitempty"` + Id string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + CreateTimeFrom int64 `json:"create_time_from,omitempty"` + CreateTimeTo int64 `json:"create_time_to,omitempty"` +} + +// ListDocumentsRes 列出文档响应 +type ListDocumentsRes struct { + Code int `json:"code"` + Data []*Document `json:"data"` + Total int `json:"total"` +} + +// DeleteDocumentsReq 删除文档请求 +type DeleteDocumentsReq struct { + Ids []string `json:"ids"` +} + +// ListDocuments 列出文档 +func (c *Client) ListDocuments(ctx context.Context, datasetId string, req *ListDocumentsReq) (*ListDocumentsRes, error) { + path := fmt.Sprintf("/api/v1/datasets/%s/documents?", datasetId) + params := map[string]interface{}{} + if req.Page > 0 { + params["page"] = req.Page + } + if req.PageSize > 0 { + params["page_size"] = req.PageSize + } + if req.OrderBy != "" { + params["orderby"] = req.OrderBy + } + if req.Desc { + params["desc"] = "true" + } else { + params["desc"] = "false" + } + if req.Keywords != "" { + params["keywords"] = req.Keywords + } + if req.Id != "" { + params["id"] = req.Id + } + if req.Name != "" { + params["name"] = req.Name + } + if req.CreateTimeFrom > 0 { + params["create_time_from"] = req.CreateTimeFrom + } + if req.CreateTimeTo > 0 { + params["create_time_to"] = req.CreateTimeTo + } + + for k, v := range params { + path += fmt.Sprintf("%s=%v&", k, v) + } + + var res ListDocumentsRes + if err := c.request(ctx, "GET", path, nil, &res); err != nil { + return nil, err + } + if res.Code != 0 { + return nil, fmt.Errorf("list documents failed: code=%d", res.Code) + } + return &res, nil +} + +// UploadDocument 上传文档 +// 注意:此方法需要特殊处理 multipart/form-data,目前的 request 方法可能不支持 +// 我们需要扩展 request 方法或在此处单独实现 +func (c *Client) UploadDocument(ctx context.Context, datasetId string, filePaths []string) error { + // TODO: 实现文件上传逻辑,需要使用 gclient 的 UploadFile 功能 + // 由于 request 方法封装了 JSON 处理,这里可能需要绕过 request 方法直接使用 c.Client + // 暂时留空或仅做简单提示,待完善 Client 封装以支持文件上传 + return fmt.Errorf("upload document not implemented yet") +} + +// DeleteDocument 删除文档 +func (c *Client) DeleteDocument(ctx context.Context, datasetId string, ids []string) error { + req := DeleteDocumentsReq{Ids: ids} + var res CommonResponse + path := fmt.Sprintf("/api/v1/datasets/%s/documents", datasetId) + if err := c.request(ctx, "DELETE", path, req, &res); err != nil { + return err + } + if !res.IsSuccess() { + return fmt.Errorf("delete document failed: %s", res.Message) + } + return nil +} diff --git a/ragflow/ragflow.go b/ragflow/ragflow.go deleted file mode 100644 index 78bf86d..0000000 --- a/ragflow/ragflow.go +++ /dev/null @@ -1,5 +0,0 @@ -package ragflow - -func init() { - -} diff --git a/ragflow/ragflow_dto.go b/ragflow/ragflow_dto.go deleted file mode 100644 index 3cdcbe1..0000000 --- a/ragflow/ragflow_dto.go +++ /dev/null @@ -1 +0,0 @@ -package ragflow diff --git a/ragflow/service/client.go b/ragflow/service/client.go new file mode 100644 index 0000000..4a45dfe --- /dev/null +++ b/ragflow/service/client.go @@ -0,0 +1,77 @@ +package service + +import ( + "context" + "fmt" + "time" + + "gitee.com/red-future---jilin-g/common/ragflow/dto" + "github.com/gogf/gf/v2/encoding/gjson" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/net/gclient" +) + +type Client struct { + BaseURL string + ApiKey string + Client *gclient.Client +} + +// NewClient 创建一个新的 RAGFlow 客户端 +func NewClient(baseUrl, apiKey string) *Client { + return &Client{ + BaseURL: baseUrl, + ApiKey: apiKey, + Client: g.Client().SetTimeout(30 * time.Second), + } +} + +// request 发送 HTTP 请求 +func (c *Client) request(ctx context.Context, method, path string, data interface{}, result interface{}) error { + url := fmt.Sprintf("%s%s", c.BaseURL, path) + + req := c.Client.Header(map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", c.ApiKey), + "Content-Type": "application/json", + }) + + var res *gclient.Response + var err error + + switch method { + case "GET": + res, err = req.Get(ctx, url, data) + case "POST": + res, err = req.Post(ctx, url, data) + case "PUT": + res, err = req.Put(ctx, url, data) + case "DELETE": + res, err = req.Delete(ctx, url, data) + default: + return fmt.Errorf("unsupported method: %s", method) + } + + if err != nil { + return err + } + defer res.Close() + + // 读取响应体 + body := res.ReadAllString() + + // 解析响应 + if result != nil { + if err := gjson.DecodeTo(body, result); err != nil { + return fmt.Errorf("failed to decode response: %v, body: %s", err, body) + } + + // 检查业务错误码 + if commonRes, ok := result.(*dto.CommonResponse); ok { + if !commonRes.IsSuccess() { + return fmt.Errorf("api error: code=%d, message=%s", commonRes.Code, commonRes.Message) + } + } + } + + return nil +} diff --git a/ragflow/session.go b/ragflow/session.go new file mode 100644 index 0000000..0c15136 --- /dev/null +++ b/ragflow/session.go @@ -0,0 +1,164 @@ +package ragflow + +import ( + "context" + "fmt" +) + +// Session 结构体 +type Session struct { + Id string `json:"id"` + Name string `json:"name"` + ChatId string `json:"chat_id"` // 响应中是 "chat" 或 "chat_id",根据文档示例调整 + Messages []Message `json:"messages"` + CreateDate string `json:"create_date"` + CreateTime int64 `json:"create_time"` + UpdateDate string `json:"update_date"` + UpdateTime int64 `json:"update_time"` +} + +type Message struct { + Content string `json:"content"` + Role string `json:"role"` +} + +// CreateSessionReq 创建会话请求 +type CreateSessionReq struct { + Name string `json:"name"` + UserId string `json:"user_id,omitempty"` +} + +// ListSessionsReq 列出会话请求 +type ListSessionsReq struct { + Page int `json:"page,omitempty"` + PageSize int `json:"page_size,omitempty"` + OrderBy string `json:"orderby,omitempty"` + Desc bool `json:"desc,omitempty"` + Name string `json:"name,omitempty"` + Id string `json:"id,omitempty"` + UserId string `json:"user_id,omitempty"` +} + +// ListSessionsRes 列出会话响应 +type ListSessionsRes struct { + Code int `json:"code"` + Data []*Session `json:"data"` + Total int `json:"total"` // API 文档未明确 +} + +// DeleteSessionsReq 删除会话请求 +type DeleteSessionsReq struct { + Ids []string `json:"ids"` +} + +// ChatCompletionReq 对话请求 +type ChatCompletionReq struct { + Question string `json:"question"` + Stream bool `json:"stream"` + SessionId string `json:"session_id,omitempty"` + UserId string `json:"user_id,omitempty"` +} + +// ChatCompletionRes 对话响应 (非流式) +type ChatCompletionRes struct { + Code int `json:"code"` + Data struct { + Answer string `json:"answer"` + Reference interface{} `json:"reference"` + AudioBinary interface{} `json:"audio_binary"` + Id interface{} `json:"id"` + SessionId string `json:"session_id"` + } `json:"data"` +} + +// CreateSession 创建会话 +func (c *Client) CreateSession(ctx context.Context, chatId string, req *CreateSessionReq) (*Session, error) { + path := fmt.Sprintf("/api/v1/chats/%s/sessions", chatId) + var res struct { + Code int `json:"code"` + Data *Session `json:"data"` + Msg string `json:"message"` + } + if err := c.request(ctx, "POST", path, req, &res); err != nil { + return nil, err + } + if res.Code != 0 { + return nil, fmt.Errorf("create session failed: %s", res.Msg) + } + return res.Data, nil +} + +// ListSessions 列出会话 +func (c *Client) ListSessions(ctx context.Context, chatId string, req *ListSessionsReq) (*ListSessionsRes, error) { + path := fmt.Sprintf("/api/v1/chats/%s/sessions?", chatId) + params := map[string]interface{}{} + if req.Page > 0 { + params["page"] = req.Page + } + if req.PageSize > 0 { + params["page_size"] = req.PageSize + } + if req.OrderBy != "" { + params["orderby"] = req.OrderBy + } + if req.Desc { + params["desc"] = "true" + } else { + params["desc"] = "false" + } + if req.Name != "" { + params["name"] = req.Name + } + if req.Id != "" { + params["id"] = req.Id + } + if req.UserId != "" { + params["user_id"] = req.UserId + } + + for k, v := range params { + path += fmt.Sprintf("%s=%v&", k, v) + } + + var res ListSessionsRes + if err := c.request(ctx, "GET", path, nil, &res); err != nil { + return nil, err + } + if res.Code != 0 { + return nil, fmt.Errorf("list sessions failed: code=%d", res.Code) + } + return &res, nil +} + +// DeleteSessions 删除会话 +func (c *Client) DeleteSessions(ctx context.Context, chatId string, ids []string) error { + req := DeleteSessionsReq{Ids: ids} + var res CommonResponse + path := fmt.Sprintf("/api/v1/chats/%s/sessions", chatId) + if err := c.request(ctx, "DELETE", path, req, &res); err != nil { + return err + } + if !res.IsSuccess() { + return fmt.Errorf("delete sessions failed: %s", res.Message) + } + return nil +} + +// ChatCompletion 对话 (目前仅支持非流式) +func (c *Client) ChatCompletion(ctx context.Context, chatId string, req *ChatCompletionReq) (*ChatCompletionRes, error) { + path := fmt.Sprintf("/api/v1/chats/%s/completions", chatId) + var res ChatCompletionRes + + // 如果需要流式支持,需要使用 gclient 的流式处理能力,这里暂只实现非流式 + if req.Stream { + return nil, fmt.Errorf("stream mode not supported yet") + } + + if err := c.request(ctx, "POST", path, req, &res); err != nil { + return nil, err + } + if res.Code != 0 { + return nil, fmt.Errorf("chat completion failed: code=%d", res.Code) + } + return &res, nil +}