144 lines
3.5 KiB
Go
144 lines
3.5 KiB
Go
package ragflow
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"net/url"
|
||
"strings"
|
||
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
"github.com/gogf/gf/v2/net/gclient"
|
||
)
|
||
|
||
var (
|
||
// globalClient 全局 RAGFlow 客户端(单例,自动初始化)
|
||
globalClient *Client
|
||
)
|
||
|
||
// init 包初始化时自动创建全局客户端
|
||
func init() {
|
||
ctx := context.Background()
|
||
|
||
// 读取配置
|
||
baseURL, apiKey := loadConfig(ctx)
|
||
|
||
// 如果配置不完整,跳过初始化
|
||
if baseURL == "" || apiKey == "" {
|
||
g.Log().Warning(ctx, "⚠️ RAGFlow 配置未找到,请在项目 config.yml 中添加 ragflow.base_url 和 ragflow.api_key")
|
||
return
|
||
}
|
||
|
||
// 初始化全局客户端
|
||
httpClient := gclient.New()
|
||
httpClient.SetHeader("Authorization", fmt.Sprintf("Bearer %s", apiKey))
|
||
httpClient.SetHeader("Content-Type", "application/json")
|
||
|
||
globalClient = &Client{
|
||
BaseURL: strings.TrimSuffix(baseURL, "/"),
|
||
APIKey: apiKey,
|
||
HTTPClient: httpClient,
|
||
}
|
||
|
||
g.Log().Infof(ctx, "✅ RAGFlow 全局客户端初始化成功: baseURL=%s", baseURL)
|
||
}
|
||
|
||
// loadConfig 从配置文件加载 RAGFlow 配置
|
||
func loadConfig(ctx context.Context) (baseURL, apiKey string) {
|
||
// 使用 GoFrame 全局配置(从项目的 config.yml 读取)
|
||
baseURL = g.Cfg().MustGet(ctx, "ragflow.base_url", "").String()
|
||
apiKey = g.Cfg().MustGet(ctx, "ragflow.api_key", "").String()
|
||
|
||
return baseURL, apiKey
|
||
}
|
||
|
||
// GetGlobalClient 获取全局客户端
|
||
// 使用示例:client := ragflow.GetGlobalClient()
|
||
func GetGlobalClient() *Client {
|
||
return globalClient
|
||
}
|
||
|
||
// Client RAGFlow API 客户端
|
||
type Client struct {
|
||
BaseURL string
|
||
APIKey string
|
||
HTTPClient *gclient.Client // HTTP 客户端
|
||
}
|
||
|
||
// CommonResponse is the generic response envelope decoded from RAGFlow API
// replies: a numeric Code (0 means success, see IsSuccess), a Message string,
// and an optional Data payload decoded into generic JSON values
// (map[string]interface{}, []interface{}, float64, ...).
type CommonResponse struct {
	Code    int         `json:"code"`
	Message string      `json:"message"`
	Data    interface{} `json:"data,omitempty"`
}

// IsSuccess reports whether the response signals success, i.e. Code == 0.
// A nil response is treated as a failure instead of panicking.
func (r *CommonResponse) IsSuccess() bool {
	return r != nil && r.Code == 0
}
|
||
|
||
// request 发送 HTTP 请求
|
||
func (c *Client) request(ctx context.Context, method, path string, body interface{}, result interface{}) error {
|
||
fullURL := c.BaseURL + path
|
||
|
||
var reqBody io.Reader
|
||
if body != nil {
|
||
jsonData, err := json.Marshal(body)
|
||
if err != nil {
|
||
return fmt.Errorf("marshal request body failed: %w", err)
|
||
}
|
||
reqBody = strings.NewReader(string(jsonData))
|
||
}
|
||
|
||
var resp *gclient.Response
|
||
var err error
|
||
|
||
switch method {
|
||
case "GET":
|
||
resp, err = c.HTTPClient.Get(ctx, fullURL)
|
||
case "POST":
|
||
resp, err = c.HTTPClient.Post(ctx, fullURL, reqBody)
|
||
case "PUT":
|
||
resp, err = c.HTTPClient.Put(ctx, fullURL, reqBody)
|
||
case "DELETE":
|
||
resp, err = c.HTTPClient.Delete(ctx, fullURL, reqBody)
|
||
default:
|
||
return fmt.Errorf("unsupported method: %s", method)
|
||
}
|
||
|
||
if err != nil {
|
||
return fmt.Errorf("http request failed: %w", err)
|
||
}
|
||
defer resp.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return fmt.Errorf("http request failed with status: %d", resp.StatusCode)
|
||
}
|
||
|
||
respBody := resp.ReadAll()
|
||
if err != nil {
|
||
return fmt.Errorf("read response body failed: %w", err)
|
||
}
|
||
|
||
if err := json.Unmarshal(respBody, result); err != nil {
|
||
return fmt.Errorf("unmarshal response failed: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// buildQueryString encodes params as a URL query string without the leading
// "?". Each value is formatted with %v; keys and values are escaped with
// url.QueryEscape (spaces become "+"). Pairs are emitted in ascending key
// order, so the same params always produce the same string. A nil or empty map
// yields "".
func buildQueryString(params map[string]interface{}) string {
	if len(params) == 0 {
		return ""
	}

	// url.Values.Encode uses the same QueryEscape encoding as before and sorts by
	// key, removing the nondeterminism of iterating a Go map directly.
	values := make(url.Values, len(params))
	for k, v := range params {
		values.Set(k, fmt.Sprintf("%v", v))
	}
	return values.Encode()
}
|