169 lines
4.5 KiB
Go
169 lines
4.5 KiB
Go
package ragflow
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
)
|
|
|
|
// Session management.
// Reference: https://ragflow.com.cn/docs/dev/http_api_reference#会话管理
|
|
|
|
// Session is a chat session belonging to a chat assistant.
//
// Responses may carry the owning chat assistant's ID under either "chat_id"
// or "chat"; UnmarshalJSON accepts both and stores the value in ChatId.
// Marshalling always emits "chat_id".
type Session struct {
	Id         string    `json:"id"`
	Name       string    `json:"name"`
	ChatId     string    `json:"chat_id"`
	Messages   []Message `json:"messages"`
	CreateDate string    `json:"create_date"`
	CreateTime int64     `json:"create_time"`
	UpdateDate string    `json:"update_date"`
	UpdateTime int64     `json:"update_time"`
}

// UnmarshalJSON decodes a Session using the struct's json tags and then,
// if ChatId is still empty, fills it from a "chat" key when present.
// When both keys are present, "chat_id" wins.
func (s *Session) UnmarshalJSON(data []byte) error {
	// sessionFields has Session's fields but none of its methods, so decoding
	// into it does not recurse back into this UnmarshalJSON.
	type sessionFields Session
	aux := struct {
		*sessionFields
		Chat string `json:"chat"`
	}{sessionFields: (*sessionFields)(s)}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}
	if s.ChatId == "" {
		s.ChatId = aux.Chat
	}
	return nil
}

// Message is a single message stored in a Session.
type Message struct {
	Content string `json:"content"`
	Role    string `json:"role"`
}
|
|
|
|
// CreateSessionReq is the JSON request body used to create a session.
// Name is always sent; UserId is omitted from the body when empty.
type CreateSessionReq struct {
	Name   string `json:"name"`
	UserId string `json:"user_id,omitempty"`
}
|
|
|
|
// ListSessionsReq carries the optional paging, ordering and filter options
// for listing sessions. ListSessions sends these as URL query parameters;
// the JSON tags use the same parameter names and omit zero values.
type ListSessionsReq struct {
	Page     int    `json:"page,omitempty"`
	PageSize int    `json:"page_size,omitempty"`
	OrderBy  string `json:"orderby,omitempty"`
	Desc     bool   `json:"desc,omitempty"`
	Name     string `json:"name,omitempty"`
	Id       string `json:"id,omitempty"`
	UserId   string `json:"user_id,omitempty"`
}
|
|
|
|
// ListSessionsRes 列出会话响应
|
|
type ListSessionsRes struct {
|
|
Code int `json:"code"`
|
|
Data []*Session `json:"data"`
|
|
Total int `json:"total"` // API 文档未明确
|
|
}
|
|
|
|
// DeleteSessionsReq is the JSON request body used to delete sessions.
// The tag has no omitempty, so a nil Ids slice is encoded as "ids": null.
type DeleteSessionsReq struct {
	Ids []string `json:"ids"`
}
|
|
|
|
// ChatCompletionReq is the JSON request body for asking a chat assistant a
// question. Stream has no omitempty, so "stream": false is always sent
// explicitly; SessionId and UserId are omitted when empty.
type ChatCompletionReq struct {
	Question  string `json:"question"`
	Stream    bool   `json:"stream"`
	SessionId string `json:"session_id,omitempty"`
	UserId    string `json:"user_id,omitempty"`
}
|
|
|
|
// ChatCompletionRes is the decoded non-streaming response of a chat
// completion: a status code plus the answer payload. Reference, AudioBinary
// and Id are decoded loosely as interface{} since their shape is not pinned
// down here.
type ChatCompletionRes struct {
	Code int `json:"code"`
	Data struct {
		Answer      string      `json:"answer"`
		Reference   interface{} `json:"reference"`
		AudioBinary interface{} `json:"audio_binary"`
		Id          interface{} `json:"id"`
		SessionId   string      `json:"session_id"`
	} `json:"data"`
}
|
|
|
|
// CreateSession 创建会话
|
|
func (c *Client) CreateSession(ctx context.Context, chatId string, req *CreateSessionReq) (*Session, error) {
|
|
path := fmt.Sprintf("/api/v1/chats/%s/sessions", chatId)
|
|
var res struct {
|
|
Code int `json:"code"`
|
|
Data *Session `json:"data"`
|
|
Msg string `json:"message"`
|
|
}
|
|
if err := c.request(ctx, "POST", path, req, &res); err != nil {
|
|
return nil, err
|
|
}
|
|
if res.Code != 0 {
|
|
return nil, fmt.Errorf("create session failed: %s", res.Msg)
|
|
}
|
|
return res.Data, nil
|
|
}
|
|
|
|
// ListSessions 列出会话
|
|
func (c *Client) ListSessions(ctx context.Context, chatId string, req *ListSessionsReq) (*ListSessionsRes, error) {
|
|
path := fmt.Sprintf("/api/v1/chats/%s/sessions?", chatId)
|
|
params := map[string]interface{}{}
|
|
if req.Page > 0 {
|
|
params["page"] = req.Page
|
|
}
|
|
if req.PageSize > 0 {
|
|
params["page_size"] = req.PageSize
|
|
}
|
|
if req.OrderBy != "" {
|
|
params["orderby"] = req.OrderBy
|
|
}
|
|
if req.Desc {
|
|
params["desc"] = "true"
|
|
} else {
|
|
params["desc"] = "false"
|
|
}
|
|
if req.Name != "" {
|
|
params["name"] = req.Name
|
|
}
|
|
if req.Id != "" {
|
|
params["id"] = req.Id
|
|
}
|
|
if req.UserId != "" {
|
|
params["user_id"] = req.UserId
|
|
}
|
|
|
|
query := buildQueryString(params)
|
|
if query != "" {
|
|
path += "?" + query
|
|
}
|
|
|
|
var res ListSessionsRes
|
|
if err := c.request(ctx, "GET", path, nil, &res); err != nil {
|
|
return nil, err
|
|
}
|
|
if res.Code != 0 {
|
|
return nil, fmt.Errorf("list sessions failed: code=%d", res.Code)
|
|
}
|
|
return &res, nil
|
|
}
|
|
|
|
// DeleteSessions 删除会话
|
|
func (c *Client) DeleteSessions(ctx context.Context, chatId string, ids []string) error {
|
|
req := DeleteSessionsReq{Ids: ids}
|
|
var res CommonResponse
|
|
path := fmt.Sprintf("/api/v1/chats/%s/sessions", chatId)
|
|
if err := c.request(ctx, "DELETE", path, req, &res); err != nil {
|
|
return err
|
|
}
|
|
if !res.IsSuccess() {
|
|
return fmt.Errorf("delete sessions failed: %s", res.Message)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ChatCompletion 对话 (目前仅支持非流式)
|
|
func (c *Client) ChatCompletion(ctx context.Context, chatId string, req *ChatCompletionReq) (*ChatCompletionRes, error) {
|
|
path := fmt.Sprintf("/api/v1/chats/%s/completions", chatId)
|
|
var res ChatCompletionRes
|
|
|
|
// 如果需要流式支持,需要使用 gclient 的流式处理能力,这里暂只实现非流式
|
|
if req.Stream {
|
|
return nil, fmt.Errorf("stream mode not supported yet")
|
|
}
|
|
|
|
if err := c.request(ctx, "POST", path, req, &res); err != nil {
|
|
return nil, err
|
|
}
|
|
if res.Code != 0 {
|
|
return nil, fmt.Errorf("chat completion failed: code=%d", res.Code)
|
|
}
|
|
return &res, nil
|
|
}
|