Files
common/ragflow/openai.go

118 lines
3.9 KiB
Go

package ragflow
import (
"context"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/errors/gerror"
)
// OpenAI-compatible API of RAGFlow.
// Reference: https://ragflow.com.cn/docs/dev/http_api_reference#与-openai-兼容的-api

// ChatCompletionMessage is a single message in the OpenAI chat format.
type ChatCompletionMessage struct {
Role string `json:"role"` // "user", "assistant", "system"
Content string `json:"content"` // text of the message
}
// ChatCompletionRequest is a chat completion request in the OpenAI format.
type ChatCompletionRequest struct {
Model string `json:"model"` // model name (the server resolves it automatically, so any value may be set)
Messages []ChatCompletionMessage `json:"messages"` // message list; must contain at least one user message
Stream bool `json:"stream,omitempty"` // whether to stream the response; defaults to false and is omitted from the JSON when false
}
// ChatCompletionResponse is a non-streaming chat completion response in the
// OpenAI format.
type ChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"` // NOTE(review): presumably a Unix timestamp, as in the OpenAI API — confirm
Model string `json:"model"`
// Choices lists the returned completions, each carrying a full message.
Choices []struct {
Index int `json:"index"`
Message ChatCompletionMessage `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
// Usage holds the token counts reported for the request.
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
// ChatCompletionChunk is one chunk of a streaming chat completion response.
type ChatCompletionChunk struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"` // NOTE(review): presumably a Unix timestamp, as in the OpenAI API — confirm
Model string `json:"model"`
// Choices carries the incremental content (delta) for each completion.
Choices []struct {
Index int `json:"index"`
Delta struct {
Content string `json:"content"`
Role string `json:"role"`
} `json:"delta"`
// FinishReason is a pointer, so a JSON null decodes to nil.
FinishReason *string `json:"finish_reason"`
} `json:"choices"`
// Usage is a pointer and stays nil when the chunk carries no usage object.
Usage *struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage,omitempty"`
}
// CreateChatCompletion creates a (non-streaming) chat completion against a
// chat assistant.
//
// It issues POST /api/v1/chats_openai/{chat_id}/chat/completions with req as
// the request payload and hands a ChatCompletionResponse to c.request as the
// output target.
//
// On failure it returns a nil response and an error that wraps the underlying
// cause, so callers can still inspect it with errors.Is / errors.As.
func (c *Client) CreateChatCompletion(ctx context.Context, chatID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
	path := "/api/v1/chats_openai/" + chatID + "/chat/completions"

	var resp ChatCompletionResponse
	if err := c.request(ctx, "POST", path, req, &resp); err != nil {
		// Wrap instead of formatting with %v: the rendered message is the
		// same ("create chat completion failed: <cause>") but the cause stays
		// reachable through the error chain.
		return nil, gerror.Wrap(err, "create chat completion failed")
	}
	return &resp, nil
}
// CreateAgentCompletion creates a (non-streaming) completion against an agent.
//
// It issues POST /api/v1/agents_openai/{agent_id}/chat/completions with req as
// the request payload and hands a ChatCompletionResponse to c.request as the
// output target.
//
// On failure it returns a nil response and an error that wraps the underlying
// cause, so callers can still inspect it with errors.Is / errors.As.
func (c *Client) CreateAgentCompletion(ctx context.Context, agentID string, req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
	path := "/api/v1/agents_openai/" + agentID + "/chat/completions"

	var resp ChatCompletionResponse
	if err := c.request(ctx, "POST", path, req, &resp); err != nil {
		// Wrap instead of formatting with %v: the rendered message is the
		// same ("create agent completion failed: <cause>") but the cause stays
		// reachable through the error chain.
		return nil, gerror.Wrap(err, "create agent completion failed")
	}
	return &resp, nil
}
// CreateChatCompletionStream is the streaming variant of chat completion
// against a chat assistant.
//
// Streaming is not implemented yet: the function only marks req as a streaming
// request (req.Stream is set to true on the caller's value) and then always
// returns a nil reader together with a "not implemented" error.
func (c *Client) CreateChatCompletionStream(ctx context.Context, chatID string, req *ChatCompletionRequest) (*StreamReader, error) {
	// Guard against a nil request so the caller gets the error below instead
	// of a nil-pointer panic.
	if req != nil {
		req.Stream = true
	}
	// TODO: implement the streaming read logic.
	return nil, gerror.New("stream mode not implemented yet")
}
// StreamReader reads a streaming response.
type StreamReader struct {
_ *gjson.Json // TODO: to be used once streaming reads are implemented
close func() error // optional hook that releases the underlying stream; may be nil
}
// ReadChunk returns the next chunk of the streaming response.
//
// TODO: implement the streaming read logic. Until then every call yields a
// nil chunk and a "not implemented" error.
func (sr *StreamReader) ReadChunk() (*ChatCompletionChunk, error) {
	notImplemented := gerror.New("stream mode not implemented yet")
	return nil, notImplemented
}
// Close closes the stream by invoking the reader's close hook.
//
// It is safe to call on a nil *StreamReader or on a reader without a close
// hook; both cases are a no-op that returns nil. Otherwise the hook's error is
// returned unchanged.
func (sr *StreamReader) Close() error {
	// A nil receiver can occur when the constructor failed and the caller
	// deferred Close before checking the error.
	if sr == nil || sr.close == nil {
		return nil
	}
	return sr.close()
}