144 lines
3.7 KiB
Go
144 lines
3.7 KiB
Go
package util
|
||
|
||
import (
|
||
"fmt"
|
||
"net/url"
|
||
"prompts-core/model/entity"
|
||
"strings"
|
||
|
||
"github.com/gogf/gf/v2/encoding/gjson"
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
"github.com/gogf/gf/v2/util/gconv"
|
||
)
|
||
|
||
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
|
||
// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段
|
||
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
|
||
// 1) 获取校验配置,并取值
|
||
requestMapping := model.RequestMapping
|
||
contentStr, ok := raw[model.ResponseBody].(string)
|
||
if !ok || contentStr == "" {
|
||
return fmt.Errorf("%s 字段为空或不是字符串", model.ResponseBody)
|
||
}
|
||
|
||
// 2) 解析 content 为 JSON 数组
|
||
var rounds []map[string]any
|
||
if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
|
||
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
|
||
}
|
||
if len(rounds) == 0 {
|
||
return fmt.Errorf("content 数组为空")
|
||
}
|
||
|
||
// 3) 逐条校验:只检查默认值为空的必填字段是否存在
|
||
for i, round := range rounds {
|
||
for path, defaultValue := range requestMapping {
|
||
if !g.IsEmpty(defaultValue) {
|
||
continue
|
||
}
|
||
if gjson.New(round).Get(path).IsNil() {
|
||
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ReverseMap 映射 payload 到 mapping
|
||
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||
jsonObj := gjson.New("{}")
|
||
for path, defaultValue := range mapping {
|
||
val := gjson.New(payload).Get(path)
|
||
if !val.IsNil() {
|
||
_ = jsonObj.Set(path, val.Val())
|
||
} else if defaultValue != nil {
|
||
_ = jsonObj.Set(path, defaultValue)
|
||
}
|
||
}
|
||
return jsonObj.Map()
|
||
}
|
||
|
||
// MapResponsePayload 映射模型响应为标准格式
|
||
func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) {
|
||
if len(mapping) == 0 {
|
||
return responseBytes, nil
|
||
}
|
||
|
||
responseJson := gjson.New(responseBytes)
|
||
resultJson := gjson.New("{}")
|
||
|
||
for standardField, modelPath := range mapping {
|
||
path := gconv.String(modelPath)
|
||
if path == "" {
|
||
continue
|
||
}
|
||
val := responseJson.Get(path)
|
||
if val.IsNil() {
|
||
continue
|
||
}
|
||
resultJson.Set(standardField, val.Val())
|
||
}
|
||
|
||
return []byte(resultJson.String()), nil
|
||
}
|
||
|
||
// ParseHeadMsgHeaders parses a comma-separated list of header bindings into a
// name/value map.
//
// Examples:
//   - X-API-Key:qwen3-tts-key,operation:true,count:123
//   - X-API-Key:"qwen3-tts-key",operation:"true"
//   - Accept:"text/html,application/json"
//
// Each binding is written as Name:Value (recommended) or Name=Value
// (compatibility form); the first ':' or '=' in the binding separates the name
// from the value, so values may themselves contain either character.
//
// Notes:
//   - HTTP header values are always strings in the end; this function only
//     deals with the textual form of the value.
//   - Double quotes surrounding a value are stripped before it is stored. They
//     let the configuration distinguish string/boolean/number spellings, and a
//     comma inside a quoted value does not start a new binding.
//   - Bindings with an empty name or an empty value, and fragments without a
//     separator, are ignored.
//
// It returns nil when headMsg is blank or yields no usable binding.
func ParseHeadMsgHeaders(headMsg string) map[string]string {
	headMsg = strings.TrimSpace(headMsg)
	if headMsg == "" {
		return nil
	}

	out := map[string]string{}
	for _, pair := range splitHeadMsgPairs(headMsg) {
		pair = strings.TrimSpace(pair)
		// Split on whichever separator comes first so that a value such as
		// "http://host" in the Name=Value form is not cut at its colon.
		sep := strings.IndexAny(pair, ":=")
		if sep < 0 {
			continue
		}
		name := strings.TrimSpace(pair[:sep])
		value := strings.Trim(strings.TrimSpace(pair[sep+1:]), "\"")
		if name != "" && value != "" {
			out[name] = value
		}
	}

	if len(out) == 0 {
		return nil
	}
	return out
}

// splitHeadMsgPairs splits s on commas that are not enclosed in double quotes.
// If the quotes in s are unbalanced, quoting is ignored and s is split on every
// comma.
func splitHeadMsgPairs(s string) []string {
	if strings.Count(s, "\"")%2 != 0 {
		return strings.Split(s, ",")
	}

	var parts []string
	inQuotes := false
	start := 0
	// Byte-wise scanning is safe: '"' and ',' are ASCII and never occur inside
	// a multi-byte UTF-8 sequence.
	for i := 0; i < len(s); i++ {
		switch s[i] {
		case '"':
			inQuotes = !inQuotes
		case ',':
			if !inQuotes {
				parts = append(parts, s[start:i])
				start = i + 1
			}
		}
	}
	return append(parts, s[start:])
}
|
||
|
||
// PayloadToQuery 将 payload 转为 url.Values
|
||
func PayloadToQuery(payload map[string]any) (url.Values, error) {
|
||
q := url.Values{}
|
||
for k, v := range payload {
|
||
if v == nil {
|
||
continue
|
||
}
|
||
q.Set(k, gconv.String(v))
|
||
}
|
||
return q, nil
|
||
}
|