Files
model-asynch/service/model_invoker.go
2026-04-23 13:53:09 +08:00

152 lines
3.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"model-asynch/model/entity"
)
// parseAPIKeyHeader splits a configured API key into an HTTP header name and
// value. Two spellings are accepted:
//
//  1. "HeaderName:HeaderValue" (recommended)
//  2. "HeaderName=HeaderValue" (kept for compatibility)
//
// The split happens at whichever separator appears first, so the value may
// itself contain the other separator (for example "X-Api-Key=ab:cd" or
// "Authorization: Bearer YWJj="). Whitespace around the whole input and around
// each part is trimmed.
//
// It returns two empty strings when apiKey is blank or contains no separator:
// a bare value has no header name, so nothing is injected rather than guessing
// one. Either part may still be empty (e.g. ":value"); callers are expected to
// check both before using them.
func parseAPIKeyHeader(apiKey string) (k, v string) {
	apiKey = strings.TrimSpace(apiKey)

	// IndexAny returns the position of the earliest ':' or '=', which keeps a
	// separator character inside the value from being mistaken for the split
	// point.
	i := strings.IndexAny(apiKey, ":=")
	if i < 0 {
		return "", ""
	}
	return strings.TrimSpace(apiKey[:i]), strings.TrimSpace(apiKey[i+1:])
}
// payloadToQuery flattens payload into URL query parameters.
//
// The payload is first round-tripped through JSON so that structs, maps and
// anything else that marshals to a JSON object are handled uniformly; the
// resulting top-level keys become parameter names. Values are rendered as
// follows:
//
//   - null values are skipped;
//   - strings are used verbatim;
//   - numbers keep the exact literal produced by json.Marshal (so 1000000 stays
//     "1000000" instead of turning into "1e+06", and large integers are not
//     rounded through float64);
//   - booleans become "true" / "false";
//   - arrays and objects are embedded as their JSON text.
//
// A nil payload yields an empty, non-nil url.Values. An error is returned when
// payload cannot be marshalled or does not marshal to a JSON object.
func payloadToQuery(payload any) (url.Values, error) {
	q := url.Values{}
	if payload == nil {
		return q, nil
	}

	raw, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	// UseNumber keeps numeric literals as json.Number instead of float64, which
	// would otherwise print in exponent form and lose integer precision.
	dec := json.NewDecoder(bytes.NewReader(raw))
	dec.UseNumber()
	var fields map[string]any
	if err := dec.Decode(&fields); err != nil {
		return nil, err
	}

	for name, val := range fields {
		switch vv := val.(type) {
		case nil:
			continue
		case string:
			q.Set(name, vv)
		case json.Number:
			q.Set(name, vv.String())
		case bool:
			q.Set(name, fmt.Sprint(vv))
		default:
			// Composite values (arrays / objects) are passed as JSON text.
			bs, err := json.Marshal(vv)
			if err != nil {
				return nil, err
			}
			q.Set(name, string(bs))
		}
	}
	return q, nil
}
// InvokeModel calls the model service described by m and returns the raw
// response body.
//
// Request construction:
//   - URL: m.BaseURL with trailing slashes removed, joined to m.Route with a
//     single "/". A blank Route means BaseURL is used on its own.
//   - Timeout: m.TimeoutMs interpreted as milliseconds; values <= 0 fall back
//     to 60 seconds.
//   - Method: m.HttpMethod, trimmed and upper-cased, defaulting to POST. GET
//     sends payload as query parameters (via payloadToQuery); every other
//     method value is sent as a POST with payload as a JSON body.
//   - Headers: non-empty entries returned by forwardHeaders(ctx) are copied
//     first, then the header parsed from m.APIKey (if any) is set, and finally
//     Content-Type: application/json for non-GET requests.
//
// Errors are returned when the model configuration is incomplete, the payload
// cannot be encoded, the request cannot be built or sent, the body cannot be
// read, or the service answers with a non-2xx status. In the last case up to
// 2000 bytes of the response body are included in the error to help
// troubleshooting.
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any) ([]byte, error) {
	if m == nil || m.BaseURL == "" {
		return nil, fmt.Errorf("模型配置不完整")
	}

	// Named "endpoint" rather than "url" so the net/url package is not shadowed.
	endpoint := strings.TrimRight(m.BaseURL, "/")
	if strings.TrimSpace(m.Route) != "" {
		endpoint += "/" + strings.TrimLeft(m.Route, "/")
	}

	timeout := time.Duration(m.TimeoutMs) * time.Millisecond
	if timeout <= 0 {
		timeout = 60 * time.Second
	}
	client := &http.Client{Timeout: timeout}

	method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
	if method == "" {
		method = http.MethodPost
	}

	// req and err are declared once here; the switch arms below must assign to
	// them with "=" and use distinct names for their own errors. Re-declaring
	// err with ":=" inside an arm would hide a request-construction failure
	// from the check after the switch and leave req nil.
	var (
		req *http.Request
		err error
	)
	switch method {
	case http.MethodGet:
		query, queryErr := payloadToQuery(payload)
		if queryErr != nil {
			return nil, queryErr
		}
		if len(query) > 0 {
			sep := "?"
			if strings.Contains(endpoint, "?") {
				// The configured URL already carries a query string.
				sep = "&"
			}
			endpoint += sep + query.Encode()
		}
		req, err = http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	default:
		// Any non-GET method is sent as POST with a JSON body.
		body, marshalErr := json.Marshal(payload)
		if marshalErr != nil {
			return nil, marshalErr
		}
		req, err = http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
	}
	if err != nil {
		return nil, err
	}

	// Pass through the headers supplied by forwardHeaders, then inject the
	// API-key header from the model configuration.
	for name, value := range forwardHeaders(ctx) {
		if value != "" {
			req.Header.Set(name, value)
		}
	}
	if hk, hv := parseAPIKeyHeader(m.APIKey); hk != "" && hv != "" {
		req.Header.Set(hk, hv)
	}
	if method != http.MethodGet {
		req.Header.Set("Content-Type", "application/json")
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Carry (a bounded part of) the error body back to ease debugging.
		msg := string(data)
		if len(msg) > 2000 {
			// Cutting at a byte offset may split a multi-byte character;
			// ToValidUTF8 drops invalid sequences, including a partial
			// trailing one, so the message stays valid text.
			msg = strings.ToValidUTF8(msg[:2000], "")
		}
		return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
	}
	return data, nil
}