118 lines
3.9 KiB
Go
118 lines
3.9 KiB
Go
package ragflow
|
|
|
|
import (
|
|
"context"
|
|
|
|
"github.com/gogf/gf/v2/encoding/gjson"
|
|
"github.com/gogf/gf/v2/errors/gerror"
|
|
)
|
|
|
|
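// OpenAI-compatible API (与 OpenAI 兼容的 API): request/response types and
// client methods for RAGFlow's OpenAI-compatible chat completion endpoints.
// Reference: https://ragflow.com.cn/docs/dev/http_api_reference#与-openai-兼容的-api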
|
|
|
|
// ChatCompletionMessage is a single chat message in OpenAI format.
type ChatCompletionMessage struct {
	Role    string `json:"role"`    // author of the message: "user", "assistant", or "system"
	Content string `json:"content"` // message text
}
|
|
|
|
// ChatCompletionRequest is an OpenAI-format chat completion request body.
type ChatCompletionRequest struct {
	Model    string                  `json:"model"`            // model name; the server resolves the model itself, so any value may be used
	Messages []ChatCompletionMessage `json:"messages"`         // conversation so far; must contain at least one "user" message
	Stream   bool                    `json:"stream,omitempty"` // request a streamed response; omitted from the JSON when false (the default)
}
|
|
|
|
// ChatCompletionResponse is an OpenAI-format chat completion response for a
// non-streaming request. CreateChatCompletion and CreateAgentCompletion decode
// the server's JSON reply into this type.
type ChatCompletionResponse struct {
	ID      string `json:"id"`      // completion ID
	Object  string `json:"object"`  // object type string reported by the server
	Created int64  `json:"created"` // creation time; presumably a Unix timestamp in seconds (OpenAI convention) — confirm against RAGFlow
	Model   string `json:"model"`   // model name reported by the server
	Choices []struct {
		Index        int                   `json:"index"`         // position of this choice
		Message      ChatCompletionMessage `json:"message"`       // generated message
		FinishReason string                `json:"finish_reason"` // why generation stopped; "" when the field is absent or null
	} `json:"choices"` // generated completion choices
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`     // token count of the prompt
		CompletionTokens int `json:"completion_tokens"` // token count of the generated completion
		TotalTokens      int `json:"total_tokens"`      // total token count reported by the server
	} `json:"usage"` // token usage; all zero if the server omits it
}
|
|
|
|
// ChatCompletionChunk is one chunk of a streaming (stream=true) chat
// completion response in OpenAI format.
type ChatCompletionChunk struct {
	ID      string `json:"id"`      // completion ID
	Object  string `json:"object"`  // object type string reported by the server
	Created int64  `json:"created"` // creation time; presumably a Unix timestamp in seconds — confirm against RAGFlow
	Model   string `json:"model"`   // model name reported by the server
	Choices []struct {
		Index int `json:"index"` // position of this choice
		Delta struct {
			Content string `json:"content"` // incremental text carried by this chunk
			Role    string `json:"role"`    // author role, when present in this chunk
		} `json:"delta"` // incremental message update
		// FinishReason is a pointer so that a JSON null decodes to nil and can
		// be told apart from an empty string (presumably null until the final
		// chunk — confirm against RAGFlow).
		FinishReason *string `json:"finish_reason"`
	} `json:"choices"`
	// Usage is a pointer so it stays nil when a chunk carries no usage data;
	// omitempty also drops it when re-encoding a chunk with nil Usage.
	Usage *struct {
		PromptTokens     int `json:"prompt_tokens"`     // token count of the prompt
		CompletionTokens int `json:"completion_tokens"` // token count of the generated completion
		TotalTokens      int `json:"total_tokens"`      // total token count reported by the server
	} `json:"usage,omitempty"`
}
|
|
|
|
// CreateChatCompletion 创建聊天补全(与聊天助手)
|
|
// POST /api/v1/chats_openai/{chat_id}/chat/completions
|
|
func (c *Client) CreateChatCompletion(ctx context.Context, chatID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
|
path := "/api/v1/chats_openai/" + chatID + "/chat/completions"
|
|
|
|
var resp ChatCompletionResponse
|
|
if err := c.request(ctx, "POST", path, req, &resp); err != nil {
|
|
return nil, gerror.Newf("create chat completion failed: %v", err)
|
|
}
|
|
|
|
return &resp, nil
|
|
}
|
|
|
|
// CreateAgentCompletion 创建 Agent 补全
|
|
// POST /api/v1/agents_openai/{agent_id}/chat/completions
|
|
func (c *Client) CreateAgentCompletion(ctx context.Context, agentID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
|
path := "/api/v1/agents_openai/" + agentID + "/chat/completions"
|
|
|
|
var resp ChatCompletionResponse
|
|
if err := c.request(ctx, "POST", path, req, &resp); err != nil {
|
|
return nil, gerror.Newf("create agent completion failed: %v", err)
|
|
}
|
|
|
|
return &resp, nil
|
|
}
|
|
|
|
// CreateChatCompletionStream 创建流式聊天补全(与聊天助手)
|
|
// 注意:流式响应需要特殊处理,这里返回一个可用于读取流的接口
|
|
func (c *Client) CreateChatCompletionStream(ctx context.Context, chatID string, req *ChatCompletionRequest) (*StreamReader, error) {
|
|
req.Stream = true
|
|
// TODO: 实现流式读取逻辑
|
|
return nil, gerror.New("stream mode not implemented yet")
|
|
}
|
|
|
|
// StreamReader reads a streaming chat completion response chunk by chunk.
// Streaming is not implemented yet; see ReadChunk and Close.
type StreamReader struct {
	// Placeholder for the JSON decoding state used once streaming is
	// implemented. As a blank field it is unreachable; presumably it exists
	// to keep the gjson import referenced — verify before removing.
	_ *gjson.Json // TODO: use when implementing stream reading

	// close releases the underlying stream; Close calls it when non-nil.
	close func() error
}
|
|
|
|
// ReadChunk 读取下一个响应块
|
|
// TODO: 实现流式读取逻辑
|
|
func (sr *StreamReader) ReadChunk() (*ChatCompletionChunk, error) {
|
|
return nil, gerror.New("stream mode not implemented yet")
|
|
}
|
|
|
|
// Close 关闭流
|
|
func (sr *StreamReader) Close() (err error) {
|
|
if sr.close != nil {
|
|
return sr.close()
|
|
}
|
|
return
|
|
}
|