package service

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"

	"model-asynch/model/entity"
)

// parseAPIKeyHeader splits a configured api_key string into a header name and value.
//
// Two forms are accepted:
//  1. "HeaderName:HeaderValue" (recommended)
//  2. "HeaderName=HeaderValue" (kept for compatibility)
//
// The split happens at the first ':' or '=' in the string. Neither character is
// legal in an HTTP header name, so the first occurrence must be the separator and
// any later ':' or '=' belongs to the value (e.g. base64 padding or "a:b" tokens).
//
// It returns ("", "") for an empty string or for a bare value with no separator,
// so nothing is injected rather than sending a malformed header. Name and value
// are whitespace-trimmed; either may come back empty and callers should check both.
func parseAPIKeyHeader(apiKey string) (k, v string) {
	apiKey = strings.TrimSpace(apiKey)
	i := strings.IndexAny(apiKey, ":=")
	if i < 0 {
		// Empty input, or only a value was given: inject nothing.
		return "", ""
	}
	return strings.TrimSpace(apiKey[:i]), strings.TrimSpace(apiKey[i+1:])
}

// payloadToQuery flattens payload into URL query parameters for GET requests.
//
// The payload is round-tripped through JSON, so any value whose JSON encoding is
// an object (struct, map) works. Each top-level field becomes one query key:
//   - null values are skipped;
//   - strings are used as-is;
//   - numbers keep their exact JSON text (e.g. "12345678", not "1.2345678e+07");
//   - booleans become "true"/"false";
//   - arrays and objects are encoded as a JSON string.
//
// A nil payload yields empty values. If payload cannot be marshaled, or its JSON
// is not an object (e.g. a slice or a bare string), the JSON error is returned.
func payloadToQuery(payload any) (url.Values, error) {
	if payload == nil {
		return url.Values{}, nil
	}
	raw, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}

	// UseNumber keeps numbers as json.Number. Decoding into float64 and printing
	// with %v would render integers >= 1e6 in exponent form and lose precision
	// above 2^53.
	dec := json.NewDecoder(bytes.NewReader(raw))
	dec.UseNumber()
	var fields map[string]any
	if err := dec.Decode(&fields); err != nil {
		return nil, err
	}

	q := make(url.Values, len(fields))
	for k, v := range fields {
		switch vv := v.(type) {
		case nil:
			continue
		case string:
			q.Set(k, vv)
		case json.Number:
			q.Set(k, vv.String())
		case bool:
			q.Set(k, strconv.FormatBool(vv))
		default:
			// Arrays/objects: these values were produced by the decoder,
			// so re-encoding them cannot fail.
			bs, _ := json.Marshal(vv)
			q.Set(k, string(bs))
		}
	}
	return q, nil
}

// newModelRequest builds the outgoing request for InvokeModel.
//
// For GET, payload is converted with payloadToQuery and appended to endpoint's
// query string (using '&' if endpoint already contains '?'). Otherwise payload
// is sent as a JSON body. POST, PUT, PATCH and DELETE use the given method; any
// other value falls back to POST, as before.
//
// Every failure (query conversion, JSON marshaling, invalid URL or method, nil
// ctx) is returned as an error, never as a nil request.
func newModelRequest(ctx context.Context, method, endpoint string, payload any) (*http.Request, error) {
	if method == http.MethodGet {
		q, err := payloadToQuery(payload)
		if err != nil {
			return nil, err
		}
		if len(q) > 0 {
			sep := "?"
			if strings.Contains(endpoint, "?") {
				sep = "&"
			}
			endpoint += sep + q.Encode()
		}
		return http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	}

	bodyMethod := http.MethodPost
	switch method {
	case http.MethodPut, http.MethodPatch, http.MethodDelete:
		bodyMethod = method
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
	return http.NewRequestWithContext(ctx, bodyMethod, endpoint, bytes.NewReader(body))
}

// InvokeModel calls the model service described by m and returns the raw response body.
//
// The target URL is m.BaseURL joined with m.Route by a single '/', or just
// m.BaseURL (trailing slashes removed) when Route is blank. m.HttpMethod
// (trimmed, upper-cased, default POST) selects how payload is sent; see
// newModelRequest. m.TimeoutMs bounds the whole exchange; a non-positive value
// means 60 seconds.
//
// Header precedence, later wins: non-empty headers from forwardHeaders(ctx),
// then the header parsed from m.APIKey, then "Content-Type: application/json"
// for non-GET requests.
//
// It returns an error if m is nil or has no BaseURL, if the request cannot be
// built or sent, if the body cannot be read, or if the status is not 2xx. In the
// last case the error contains the status code and up to 2000 bytes of the body.
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any) ([]byte, error) {
	if m == nil || m.BaseURL == "" {
		return nil, fmt.Errorf("模型配置不完整")
	}

	base := strings.TrimRight(m.BaseURL, "/")
	endpoint := base
	if strings.TrimSpace(m.Route) != "" {
		endpoint = base + "/" + strings.TrimLeft(m.Route, "/")
	}

	timeout := time.Duration(m.TimeoutMs) * time.Millisecond
	if timeout <= 0 {
		timeout = 60 * time.Second
	}
	client := &http.Client{Timeout: timeout}

	method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
	if method == "" {
		method = http.MethodPost
	}

	// Building the request in a helper that returns its own error avoids the
	// shadowed-err pattern that could leave req nil and panic below.
	req, err := newModelRequest(ctx, method, endpoint, payload)
	if err != nil {
		return nil, err
	}

	// Forward required headers (e.g. Authorization / X-User-Info), then inject
	// the api_key header from the model configuration.
	for k, v := range forwardHeaders(ctx) {
		if v != "" {
			req.Header.Set(k, v)
		}
	}
	if hk, hv := parseAPIKeyHeader(m.APIKey); hk != "" && hv != "" {
		req.Header.Set(hk, hv)
	}
	if req.Method != http.MethodGet {
		req.Header.Set("Content-Type", "application/json")
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	b, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Include the error body to help debugging. Cut it on a rune boundary
		// (backing off at most 3 bytes) so multi-byte text such as Chinese
		// messages is not split mid-character.
		const maxErrBody = 2000
		msg := b
		if len(msg) > maxErrBody {
			cut := maxErrBody
			for i := 0; i < utf8.UTFMax-1 && !utf8.RuneStart(msg[cut]); i++ {
				cut--
			}
			msg = msg[:cut]
		}
		return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
	}
	return b, nil
}