Files
data-engine/service/sync/api_client.go

534 lines
13 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package sync
import (
"bytes"
"context"
"crypto/md5"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"math/big"
"net/http"
"net/url"
"sort"
"strings"
"time"
"github.com/sirupsen/logrus"
)
// ApiResult is the outcome of a single HTTP exchange performed by ApiClient.
// For responses with status >= 400 it is returned alongside an error so the
// caller can still inspect the body.
type ApiResult struct {
	// Body holds the complete response body as read from the server.
	Body []byte
	// DurationMs is the elapsed time of the exchange in milliseconds, measured
	// from after the rate-limit wait until the body has been read.
	DurationMs int64
}
// ApiClient is a general-purpose HTTP client for one platform described by a
// PlatformConfig. It layers authentication, optional request signing,
// retries and a ticker-based rate limit on top of net/http.
type ApiClient struct {
	// config carries the platform settings (API URLs, auth, retry policy).
	config *PlatformConfig
	// client is the shared HTTP client (request timeout + pooled transport).
	client *http.Client
	// rateLimiter paces requests; nil when rate limiting is disabled.
	// Stop it with Close.
	rateLimiter *time.Ticker
}
// NewApiClient builds an ApiClient for the given platform configuration.
//
// The overall request timeout is config.RequestTimeoutMs (default 30s). When
// config.RateLimitPerMinute > 0, a ticker spaces requests evenly across a
// minute; call Close to stop it when the client is no longer needed.
func NewApiClient(config *PlatformConfig) *ApiClient {
	timeout := 30 * time.Second
	if config.RequestTimeoutMs > 0 {
		timeout = time.Duration(config.RequestTimeoutMs) * time.Millisecond
	}

	// Start from the default transport so proxy settings from the environment,
	// dial / TLS-handshake timeouts and HTTP/2 support are kept; a zero-value
	// http.Transport has none of these. Only the idle-pool sizing is tuned.
	var transport *http.Transport
	if dt, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = dt.Clone()
	} else {
		transport = &http.Transport{Proxy: http.ProxyFromEnvironment}
	}
	transport.MaxIdleConns = 100
	transport.MaxIdleConnsPerHost = 20
	transport.IdleConnTimeout = 90 * time.Second

	ac := &ApiClient{
		config: config,
		client: &http.Client{
			Timeout:   timeout,
			Transport: transport,
		},
	}

	// Rate limiting: one tick per allowed request.
	if config.RateLimitPerMinute > 0 {
		interval := time.Minute / time.Duration(config.RateLimitPerMinute)
		// Integer division yields 0 for absurdly large limits, and
		// time.NewTicker panics on a non-positive interval.
		if interval <= 0 {
			interval = time.Nanosecond
		}
		ac.rateLimiter = time.NewTicker(interval)
		logrus.Infof("限流已启用: %d 次/分钟, 间隔 %v", config.RateLimitPerMinute, interval)
	}
	return ac
}
// Get sends a GET request to path without any parameters or body.
func (c *ApiClient) Get(ctx context.Context, path string) (*ApiResult, error) {
	var noParams interface{}
	return c.doRequest(ctx, http.MethodGet, path, noParams, false)
}
// PostJSON sends a POST request to path with body as the payload.
//
// Note: despite the name, execute sends a map[string]interface{} body as
// application/x-www-form-urlencoded; any other non-nil value is JSON-encoded.
func (c *ApiClient) PostJSON(ctx context.Context, path string, body interface{}) (*ApiResult, error) {
	const paramsInQuery = false
	return c.doRequest(ctx, http.MethodPost, path, body, paramsInQuery)
}
// Close releases the client's resources, currently only the rate-limit
// ticker. Stopping a ticker does not close its channel, so a rate-limited
// client used after Close will wait on the limiter until the request's
// context ends.
func (c *ApiClient) Close() {
	if c.rateLimiter == nil {
		return
	}
	c.rateLimiter.Stop()
}
// Request is the general-purpose entry point for any HTTP method.
//
// When paramsInQuery is true, params are appended to the URL query string;
// otherwise they are sent as an application/x-www-form-urlencoded body.
// Retries follow the client's retry policy (see doRequest).
func (c *ApiClient) Request(ctx context.Context, method, path string, params map[string]interface{}, paramsInQuery bool) (*ApiResult, error) {
	res, err := c.doRequest(ctx, method, path, params, paramsInQuery)
	return res, err
}
// doRequest runs execute with retries.
//
// It makes up to MaxRetries+1 attempts (MaxRetries defaults to 3 when <= 0),
// waiting RetryDelayMs*attempt between them (linear backoff, default 1s per
// step). The backoff wait is abandoned as soon as ctx is done, and the
// returned error then wraps ctx.Err(). After the final failed attempt, the
// last result (non-nil for HTTP >= 400 responses) is returned with a wrapped
// error.
func (c *ApiClient) doRequest(ctx context.Context, method, path string, body interface{}, paramsInQuery bool) (result *ApiResult, err error) {
	maxRetries := c.config.MaxRetries
	if maxRetries <= 0 {
		maxRetries = 3
	}
	retryDelay := time.Duration(c.config.RetryDelayMs) * time.Millisecond
	if retryDelay <= 0 {
		retryDelay = 1 * time.Second
	}

	for attempt := 0; attempt <= maxRetries; attempt++ {
		result, err = c.execute(ctx, method, path, body, paramsInQuery)
		if err == nil {
			return result, nil
		}
		logrus.Warnf("请求失败 (attempt %d/%d): %v", attempt+1, maxRetries+1, err)
		if attempt == maxRetries {
			break
		}
		// Back off, but give up immediately when the caller cancels; a plain
		// time.Sleep would block the goroutine and then retry a request that
		// nobody is waiting for any more.
		timer := time.NewTimer(retryDelay * time.Duration(attempt+1))
		select {
		case <-ctx.Done():
			timer.Stop()
			return result, fmt.Errorf("等待重试时上下文已结束 (最后一次错误: %v): %w", err, ctx.Err())
		case <-timer.C:
		}
	}
	return result, fmt.Errorf("请求已重试 %d 次仍失败: %w", maxRetries, err)
}
// execute performs a single HTTP attempt (no retries).
//
//  1. Wait for the rate limiter (if enabled) or for ctx to end.
//  2. Build the URL with GetApiUrl and add URL auth params (applyAuthURL).
//  3. Encode the payload:
//     - map body, !paramsInQuery: form body. Auth params that applyAuthURL
//       put in the URL are moved into the form (explicit body values win)
//       and removed from the URL. The caller's map is never modified.
//     - other non-nil body, !paramsInQuery: JSON body.
//     - map body, paramsInQuery: params appended to the URL query.
//  4. Apply the query/form signature (applySignature). For form bodies the
//     resulting "sign" is moved from the URL into the body.
//  5. Set auth headers, User-Agent and Content-Type, send, and read the body.
//
// A response with status >= 400 is returned together with an error.
func (c *ApiClient) execute(ctx context.Context, method, path string, body interface{}, paramsInQuery bool) (*ApiResult, error) {
	// Rate-limit wait. The ticker is shared by every request on this client,
	// so it must NOT be stopped here: a stopped ticker never fires again and
	// all later requests would block until their own context ended.
	if c.rateLimiter != nil {
		select {
		case <-c.rateLimiter.C:
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}

	start := time.Now()
	fullURL := c.applyAuthURL(c.config.GetApiUrl(path))

	hasBody := body != nil && !paramsInQuery
	var (
		reqBodyBytes []byte
		isForm       bool
		// payload is what applySignature signs; for form bodies it is a
		// private copy of the caller's map that includes the auth params.
		payload = body
	)

	if hasBody {
		if paramsMap, ok := body.(map[string]interface{}); ok {
			isForm = true
			// Copy before injecting auth params. Writing into the caller's map
			// would leak this attempt's timestamp/nonce into retries and later
			// calls (existing keys are never overwritten), and would panic on
			// a nil map.
			formParams := make(map[string]interface{}, len(paramsMap))
			for k, v := range paramsMap {
				formParams[k] = v
			}
			if parsed, _ := url.Parse(fullURL); parsed != nil {
				q := parsed.Query()
				for k, vs := range q {
					if len(vs) == 0 {
						continue
					}
					if _, exists := formParams[k]; !exists {
						formParams[k] = vs[0]
					}
					q.Del(k)
				}
				parsed.RawQuery = q.Encode()
				fullURL = parsed.String()
			}
			payload = formParams
			reqBodyBytes = []byte(c.buildFormBody(formParams))
		} else {
			b, err := json.Marshal(body)
			if err != nil {
				return nil, fmt.Errorf("JSON序列化请求体失败: %w", err)
			}
			reqBodyBytes = b
		}
	} else if body != nil {
		// paramsInQuery: only map bodies are supported; they go into the URL.
		if paramsMap, ok := body.(map[string]interface{}); ok {
			fullURL = c.buildQueryURL(fullURL, paramsMap)
		}
	}

	fullURL = c.applySignature(fullURL, payload, paramsInQuery)

	// Form bodies carry the signature in the body. JSON bodies keep it in the
	// URL: appending "&sign=..." to a JSON document would corrupt it.
	if isForm {
		if parsed, _ := url.Parse(fullURL); parsed != nil {
			q := parsed.Query()
			if signVal := q.Get("sign"); signVal != "" {
				if len(reqBodyBytes) > 0 {
					reqBodyBytes = append(reqBodyBytes, '&')
				}
				reqBodyBytes = append(reqBodyBytes, "sign="+url.QueryEscape(signVal)...)
				q.Del("sign")
				parsed.RawQuery = q.Encode()
				fullURL = parsed.String()
			}
		}
	}

	var reqBody io.Reader
	contentType := ""
	if hasBody {
		reqBody = bytes.NewReader(reqBodyBytes)
		contentType = "application/json"
		if isForm {
			contentType = "application/x-www-form-urlencoded"
		}
	}

	// Equivalent curl command for debugging. Logged at debug level because the
	// URL and body may carry access tokens and app keys. The body is already
	// encoded, so it is passed verbatim with --data-raw (--data-urlencode
	// would encode it a second time).
	curlCmd := fmt.Sprintf("curl -X %s '%s'", method, fullURL)
	if contentType != "" {
		curlCmd += fmt.Sprintf(" -H 'Content-Type: %s'", contentType)
	}
	if len(reqBodyBytes) > 0 {
		curlCmd += fmt.Sprintf(" --data-raw '%s'", reqBodyBytes)
	}
	logrus.Debugf("等效curl: %s", curlCmd)

	req, err := http.NewRequestWithContext(ctx, method, fullURL, reqBody)
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	c.applyAuthHeader(req, reqBodyBytes)
	req.Header.Set("User-Agent", "data-engine/1.0")
	if contentType != "" {
		req.Header.Set("Content-Type", contentType)
	}

	resp, err := c.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应体失败: %w", err)
	}
	result := &ApiResult{Body: respBody, DurationMs: time.Since(start).Milliseconds()}
	if resp.StatusCode >= 400 {
		return result, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(respBody))
	}
	return result, nil
}
// buildQueryURL adds params to rawURL's query string and returns the new URL.
// Existing query keys with the same name are replaced. Values are rendered by
// formatRequestParam. If rawURL cannot be parsed, the error is logged and
// rawURL is returned unchanged.
func (c *ApiClient) buildQueryURL(rawURL string, params map[string]interface{}) string {
	parsed, err := url.Parse(rawURL)
	if err != nil || parsed == nil {
		logrus.Errorf("buildQueryURL: 解析 URL 失败: %v", err)
		return rawURL
	}
	q := parsed.Query()
	for k, v := range params {
		q.Set(k, formatRequestParam(v))
	}
	parsed.RawQuery = q.Encode()
	return parsed.String()
}

// buildFormBody encodes params as an application/x-www-form-urlencoded string
// (keys sorted, as url.Values.Encode does). Values are rendered by
// formatRequestParam, the same rules buildQueryURL uses. A parameter is
// therefore sent identically whether it travels in the query string or in the
// body; previously arrays/objects were JSON in the query but Go's "%v" form
// ("[a b]", "map[k:v]") in the body.
func (c *ApiClient) buildFormBody(params map[string]interface{}) string {
	form := make(url.Values, len(params))
	for k, v := range params {
		form.Set(k, formatRequestParam(v))
	}
	return form.Encode()
}

// formatRequestParam renders one parameter value as the string put on the
// wire:
//   - string: as-is; bool: "true"/"false";
//   - float64 (the type JSON numbers decode to): integral values as plain
//     integers, avoiding fmt's exponent form (e.g. "1e+06"); others via %v;
//   - float32: %v; signed/unsigned integers: base 10;
//   - []interface{} and map[string]interface{}: JSON-encoded ("" if that fails);
//   - anything else: fmt's %v.
func formatRequestParam(v interface{}) string {
	switch val := v.(type) {
	case string:
		return val
	case bool:
		if val {
			return "true"
		}
		return "false"
	case float64:
		if val == float64(int64(val)) {
			return fmt.Sprintf("%d", int64(val))
		}
		return fmt.Sprintf("%v", val)
	case float32:
		return fmt.Sprintf("%v", val)
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		return fmt.Sprintf("%d", val)
	case []interface{}, map[string]interface{}:
		b, err := json.Marshal(val)
		if err != nil {
			logrus.Warnf("参数 JSON 序列化失败: %v", err)
			return ""
		}
		return string(b)
	default:
		return fmt.Sprintf("%v", v)
	}
}
// applyAuthURL adds URL-level authentication parameters to rawURL, driven by
// config.AuthConfig:
//   - token_in_query (bool): put config.AccessToken under query_key
//     (default "access_token");
//   - extra_query_params (map): extra params whose values may contain the
//     placeholders {timestamp} (Unix seconds), {timestamp_ms} (Unix millis)
//     and {nonce} (generateNonce);
//   - app_key (string): sent as "appkey".
//
// rawURL is returned unchanged when AuthConfig is nil or cannot be applied.
// NOTE(review): app_key is only injected when token_in_query or
// extra_query_params is also in effect, because of the early return below;
// confirm that this is intended.
func (c *ApiClient) applyAuthURL(rawURL string) string {
	cfg := c.config.AuthConfig
	if cfg == nil {
		return rawURL
	}
	token := c.config.AccessToken
	tokenInQuery, _ := cfg["token_in_query"].(bool)
	queryKey, _ := cfg["query_key"].(string)
	if queryKey == "" {
		queryKey = "access_token"
	}

	// Read the clock once so {timestamp} and {timestamp_ms} describe the same
	// instant across all params, even when a second boundary falls between
	// substitutions.
	now := time.Now()
	timestamp := fmt.Sprintf("%d", now.Unix())
	timestampMs := fmt.Sprintf("%d", now.UnixMilli())

	extraParams := make(map[string]string)
	if eq, ok := cfg["extra_query_params"].(map[string]interface{}); ok {
		for k, v := range eq {
			val := fmt.Sprintf("%v", v)
			val = strings.ReplaceAll(val, "{timestamp}", timestamp)
			val = strings.ReplaceAll(val, "{timestamp_ms}", timestampMs)
			val = strings.ReplaceAll(val, "{nonce}", generateNonce())
			extraParams[k] = val
		}
	}

	if !tokenInQuery && len(extraParams) == 0 {
		return rawURL
	}
	parsed, err := url.Parse(rawURL)
	if err != nil || parsed == nil {
		logrus.Errorf("applyAuthURL: 解析 URL 失败: %v", err)
		return rawURL
	}
	q := parsed.Query()
	if tokenInQuery && token != "" {
		q.Set(queryKey, token)
	}
	for k, v := range extraParams {
		q.Set(k, v)
	}
	if appKey, ok := cfg["app_key"].(string); ok && appKey != "" {
		q.Set("appkey", appKey)
	}
	parsed.RawQuery = q.Encode()
	return parsed.String()
}
// applyAuthHeader sets authentication headers on req, by precedence:
//  1. AuthType "APP_SIGNATURE": delegate to applyAppSignatureAuth (app-id and
//     body-signature headers) and stop.
//  2. AuthConfig token_in_query == true: the token travels in the URL, so no
//     header is set.
//  3. Empty AccessToken: nothing to send.
//  4. Non-empty AuthConfig header_name: that header, with value header_format
//     (default "{token}") with "{token}" replaced by the token.
//  5. Otherwise by AuthType: OAUTH2/TOKEN → "Authorization: Bearer <token>",
//     API_KEY → "X-API-Key: <token>".
func (c *ApiClient) applyAuthHeader(req *http.Request, bodyBytes []byte) {
	if c.config.AuthType == "APP_SIGNATURE" {
		c.applyAppSignatureAuth(req, bodyBytes)
		return
	}
	// Indexing a nil map is safe in Go and yields the zero value.
	cfg := c.config.AuthConfig
	if tiq, _ := cfg["token_in_query"].(bool); tiq {
		return
	}
	token := c.config.AccessToken
	if token == "" {
		return
	}
	// An empty header_name would produce an invalid header field name, which
	// the HTTP transport rejects; fall back to the AuthType default instead.
	if name, _ := cfg["header_name"].(string); name != "" {
		format := "{token}"
		if fv, ok := cfg["header_format"].(string); ok {
			format = fv
		}
		req.Header.Set(name, strings.ReplaceAll(format, "{token}", token))
		return
	}
	switch c.config.AuthType {
	case "OAUTH2", "TOKEN":
		req.Header.Set("Authorization", "Bearer "+token)
	case "API_KEY":
		req.Header.Set("X-API-Key", token)
	}
}
// applyAppSignatureAuth sets the headers used by AuthType "APP_SIGNATURE".
//
// It sets two headers:
//   - an app-id header, named by AuthConfig app_id_header (default "app-id"),
//     whose value is config.AppKey, falling back to AuthConfig app_id;
//   - a signature header, named by AuthConfig sign_header (default
//     "signature"), whose value is computeBodySignature over bodyBytes with
//     config.AppSecret and AuthConfig sign_algorithm (default
//     "md5_upper_body").
//
// Nothing is set when AuthConfig is nil, and each header is skipped when its
// value would be empty.
func (c *ApiClient) applyAppSignatureAuth(req *http.Request, bodyBytes []byte) {
	cfg := c.config.AuthConfig
	if cfg == nil {
		return
	}
	// setting returns cfg[key] when it is a non-empty string, else fallback.
	setting := func(key, fallback string) string {
		if s, _ := cfg[key].(string); s != "" {
			return s
		}
		return fallback
	}

	appID := c.config.AppKey
	if appID == "" {
		appID = setting("app_id", "")
	}
	if appID != "" {
		req.Header.Set(setting("app_id_header", "app-id"), appID)
	}

	algorithm := setting("sign_algorithm", "md5_upper_body")
	if signature := computeBodySignature(bodyBytes, c.config.AppSecret, algorithm); signature != "" {
		req.Header.Set(setting("sign_header", "signature"), signature)
	}
}
// computeBodySignature signs a request body with a shared secret.
//
// Supported algorithms:
//   - "md5_upper_body": hex(MD5(body + secret)) in upper case;
//   - "md5_body":       the same digest in lower case.
//
// It returns "" when secret is empty (checked first), or, after logging a
// warning, when algo is not one of the above.
func computeBodySignature(bodyBytes []byte, secret, algo string) string {
	if secret == "" {
		return ""
	}

	var upperCase bool
	switch algo {
	case "md5_upper_body":
		upperCase = true
	case "md5_body":
		upperCase = false
	default:
		logrus.Warnf("未知签名算法: %s", algo)
		return ""
	}

	// hash.Hash.Write never returns an error.
	hasher := md5.New()
	hasher.Write(bodyBytes)
	hasher.Write([]byte(secret))
	digest := hex.EncodeToString(hasher.Sum(nil))
	if upperCase {
		digest = strings.ToUpper(digest)
	}
	return digest
}
// generateNonce returns a 16-digit decimal nonce. The first 12 digits are
// the current Unix time in nanoseconds modulo 10^12, and the last 4 come
// from crypto/rand. The nonce is not guaranteed unique: two calls in the
// same nanosecond collide with probability 1/10000.
func generateNonce() string {
	nanoPart := time.Now().UnixNano() % 1000000000000
	// rand.Int only fails if the system randomness source fails. Use 0 in
	// that case instead of dereferencing a nil *big.Int, so the nonce keeps
	// its 16-digit shape.
	var randPart int64
	if r, err := rand.Int(rand.Reader, big.NewInt(10000)); err == nil {
		randPart = r.Int64()
	}
	return fmt.Sprintf("%012d%04d", nanoPart, randPart)
}
// applySignature computes the query/form signature and adds it to rawURL's
// query as "sign".
//
// Kuaishou-style signing: every parameter except "sign" is sorted by key and
// joined as "k1=v1&k2=v2&..." The signed parameters are the URL query params
// plus, for body requests (!paramsInQuery), the body map's values. Then
// "&signSecret=<secret>" is appended and the string is MD5-hashed.
//
// AuthConfig sign_algorithm selects "md5" (lower-case hex) or "md5_upper".
// The secret comes from sign_secret, else app_secret, else config.AppSecret.
// rawURL is returned unchanged when signing is not configured, the algorithm
// is anything else, or the URL cannot be parsed.
//
// Body values are signed but NOT written into the returned URL, so POST
// parameters are not duplicated in the query string.
func (c *ApiClient) applySignature(rawURL string, body interface{}, paramsInQuery bool) string {
	cfg := c.config.AuthConfig
	if cfg == nil {
		return rawURL
	}
	// Reject unsupported algorithms up front (e.g. the header-based
	// "md5_upper_body" used by APP_SIGNATURE) before doing any work.
	signAlgo, _ := cfg["sign_algorithm"].(string)
	if signAlgo != "md5" && signAlgo != "md5_upper" {
		return rawURL
	}

	signSecret, _ := cfg["sign_secret"].(string)
	if signSecret == "" {
		signSecret, _ = cfg["app_secret"].(string)
	}
	if signSecret == "" {
		signSecret = c.config.AppSecret
	}
	if signSecret == "" {
		return rawURL
	}

	parsed, err := url.Parse(rawURL)
	if err != nil || parsed == nil {
		logrus.Errorf("applySignature: 解析 URL 失败: %v", err)
		return rawURL
	}
	query := parsed.Query()

	// signed = URL params + (for body requests) body params.
	signed := make(url.Values, len(query))
	for k, vs := range query {
		signed[k] = vs
	}
	if !paramsInQuery {
		if bodyMap, ok := body.(map[string]interface{}); ok {
			// Render values with buildFormBody, the same encoder that produces
			// the form body actually sent, so the signed text matches the wire
			// values. fmt's %v would sign float64 1234567 as "1.234567e+06"
			// while the body carries "1234567".
			bodyValues, perr := url.ParseQuery(c.buildFormBody(bodyMap))
			if perr != nil {
				logrus.Errorf("applySignature: 解析请求体参数失败: %v", perr)
				return rawURL
			}
			for k := range bodyValues {
				signed.Set(k, bodyValues.Get(k))
			}
		}
	}

	keys := make([]string, 0, len(signed))
	for k := range signed {
		if k != "sign" {
			keys = append(keys, k)
		}
	}
	sort.Strings(keys)

	var sb strings.Builder
	for i, k := range keys {
		if i > 0 {
			sb.WriteByte('&')
		}
		sb.WriteString(k)
		sb.WriteByte('=')
		sb.WriteString(signed.Get(k))
	}
	// Log before appending the secret: the secret must never reach the logs.
	logrus.Debugf("签名原文(不含 signSecret): %s", sb.String())
	sb.WriteString("&signSecret=")
	sb.WriteString(signSecret)

	digest := md5.Sum([]byte(sb.String()))
	sign := hex.EncodeToString(digest[:])
	if signAlgo == "md5_upper" {
		sign = strings.ToUpper(sign)
	}
	logrus.Debugf("签名值 sign=%s", sign)

	query.Set("sign", sign)
	parsed.RawQuery = query.Encode()
	return parsed.String()
}