// File: model-gateway/common/util/mapping.go
//
// NOTE: the hosting page flagged this file as containing ambiguous Unicode
// characters, i.e. characters that might be confused with other characters.

package util
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"model-gateway/model/entity"
"net/http"
"net/url"
"regexp"
"strings"
"time"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
tgjson "github.com/tidwall/gjson"
)
// ParseAndValidate decodes the JSON array stored as a string under
// raw[model.ResponseBody] and checks its object elements against
// model.RequiredFields.
//
// On success it returns a new map holding "total_rounds" (the number of
// decoded elements) and "rounds" (the decoded []any). On any failure it
// returns raw unchanged together with an error describing the first problem
// found. A nil model is reported as an error rather than dereferenced.
//
// Elements of the array that are not JSON objects are skipped by the
// required-field check.
func ParseAndValidate(raw map[string]any, model *entity.AsynchModel) (map[string]any, error) {
	if model == nil {
		return raw, fmt.Errorf("model 为空")
	}

	// 1) Decode the content string into the rounds array.
	contentVal, ok := raw[model.ResponseBody]
	if !ok {
		return raw, fmt.Errorf("字段 %s 不存在", model.ResponseBody)
	}
	contentStr, ok := contentVal.(string)
	if !ok || strings.TrimSpace(contentStr) == "" {
		return raw, fmt.Errorf("字段 %s 为空或不是字符串", model.ResponseBody)
	}
	var arr []any
	if err := json.Unmarshal([]byte(contentStr), &arr); err != nil {
		return raw, fmt.Errorf("JSON解析失败: %w", err)
	}
	if len(arr) == 0 {
		return raw, fmt.Errorf("解析后数组为空")
	}

	// 2) Check the required fields of every object element.
	if len(model.RequiredFields) > 0 {
		for i, r := range arr {
			round, ok := r.(map[string]any)
			if !ok {
				continue
			}
			// Build the lookup document once per round instead of once per field.
			doc := gjson.New(round)
			for _, field := range model.RequiredFields {
				if doc.Get(field).IsNil() {
					return raw, fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
				}
			}
		}
	}
	return map[string]any{"total_rounds": len(arr), "rounds": arr}, nil
}
// ParseStructResult wraps the content found at raw[responseBody] into the
// single-round shape {"total_rounds": 1, "rounds": [{responseBody: X}]}.
//
// The content is converted to a string with gconv.String, and X is then:
//   - the whole raw map, when that string is "" or "0";
//   - the decoded JSON value (array, object or scalar), when the string is
//     valid JSON;
//   - the string itself otherwise.
func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
	singleRound := func(content any) map[string]any {
		return map[string]any{
			"total_rounds": 1,
			"rounds":       []map[string]any{{responseBody: content}},
		}
	}

	text := gconv.String(raw[responseBody])
	switch text {
	case "", "0":
		// Nothing usable under the key: hand back the full raw map.
		return singleRound(raw)
	}

	// A generic decode yields the same value a dedicated []any decode would
	// for arrays, so a single attempt covers arrays, objects and scalars.
	var decoded any
	if json.Unmarshal([]byte(text), &decoded) != nil {
		// Fallback: keep the original string as the content.
		return singleRound(text)
	}
	return singleRound(decoded)
}
// ValidatePromptResult checks the structural completeness of a model result.
//
// raw must contain a non-empty "rounds" entry, given either as []any or as
// []map[string]any. When model.RequiredFields is non-empty, every object
// round must provide a non-nil value for each required field path; the first
// missing field is reported as an error. Rounds that are not objects are
// skipped. A nil model is treated as having no required fields.
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
	// 1) Fetch rounds and normalise the two accepted slice shapes.
	roundsRaw, ok := raw["rounds"]
	if !ok {
		return fmt.Errorf("缺少 rounds 字段")
	}
	var rounds []any
	switch v := roundsRaw.(type) {
	case []any:
		rounds = v
	case []map[string]any:
		rounds = make([]any, len(v))
		for i, m := range v {
			rounds[i] = m
		}
	default:
		return fmt.Errorf("rounds 不是数组")
	}
	if len(rounds) == 0 {
		return fmt.Errorf("rounds 数组为空")
	}

	// 2) Nothing to check when no required fields are configured.
	if model == nil || len(model.RequiredFields) == 0 {
		return nil
	}

	// 3) Check each round in order.
	for i, r := range rounds {
		round, ok := r.(map[string]any)
		if !ok {
			continue
		}
		// Build the lookup document once per round instead of once per field.
		doc := gjson.New(round)
		for _, field := range model.RequiredFields {
			if doc.Get(field).IsNil() {
				return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, field)
			}
		}
	}
	return nil
}
// validateRequiredFields checks that round provides a non-nil value for every
// path in requiredFields. It returns an error naming the first missing field,
// prefixed with prefix, or nil when all fields are present or requiredFields
// is empty.
func validateRequiredFields(round map[string]any, requiredFields []string, prefix string) error {
	if len(requiredFields) == 0 {
		return nil
	}
	// The lookup document does not depend on the field, so build it once.
	doc := gjson.New(round)
	for _, field := range requiredFields {
		if doc.Get(field).IsNil() {
			return fmt.Errorf("%s 缺少必填字段: %s", prefix, field)
		}
	}
	return nil
}
// ParseHeadMsgHeaders converts the head_msg map into request headers by
// turning every value into a string with gconv.String. It returns nil when
// headMsg is empty.
//
// Example head_msg:
//
//	{
//	  "Authorization": "Bearer xxx",
//	  "Content-Type": "application/json",
//	  "X-Api-App-Id": "5147401364",
//	  "X-Api-Access-Key": "VCqRX7..."
//	}
func ParseHeadMsgHeaders(headMsg map[string]any) map[string]string {
	count := len(headMsg)
	if count == 0 {
		return nil
	}
	headers := make(map[string]string, count)
	for name := range headMsg {
		headers[name] = gconv.String(headMsg[name])
	}
	return headers
}
// MapResponsePayload projects a model response onto standard field names.
//
// mapping maps each standard field name to a tidwall/gjson path evaluated
// against result. Entries with an empty path, or whose path does not exist in
// result, are omitted from the output. Paths containing "#" are collected
// with Result.Array into a []any; all other paths yield the single value.
//
// When mapping is empty, result is returned unchanged. An error is returned
// only when result cannot be serialized to JSON.
func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) {
	if len(mapping) == 0 {
		return result, nil
	}

	// tidwall/gjson works on JSON text, so serialize result first. A failure
	// here used to be ignored, which silently produced an empty mapping.
	resultBytes, err := json.Marshal(result)
	if err != nil {
		return nil, fmt.Errorf("序列化响应失败: %w", err)
	}
	resultStr := string(resultBytes)

	mapped := make(map[string]any, len(mapping))
	for standardField, modelPath := range mapping {
		path := gconv.String(modelPath)
		if path == "" {
			continue
		}
		value := tgjson.Get(resultStr, path)
		if !value.Exists() {
			continue
		}
		// Array-style paths (containing #) are expanded element by element;
		// everything else is taken as a single value.
		if !strings.Contains(path, "#") {
			mapped[standardField] = value.Value()
			continue
		}
		// Left nil when there are no elements, as before.
		var arr []any
		for _, v := range value.Array() {
			arr = append(arr, v.Value())
		}
		mapped[standardField] = arr
	}
	return mapped, nil
}
// GetModelBody extracts the model body from a stored model record: when v has
// a "body" key, its value converted with gconv.Map is returned; otherwise v
// itself is returned. A nil v yields nil.
func GetModelBody(v map[string]any) map[string]any {
	if v == nil {
		return nil
	}
	body, found := v["body"]
	if !found {
		return v
	}
	return gconv.Map(body)
}
// BodyToQuery turns payload into url.Values, converting each value to a
// string with gconv.String. Entries whose value is nil are left out. The
// returned error is always nil.
func BodyToQuery(payload map[string]any) (url.Values, error) {
	values := make(url.Values, len(payload))
	for key, val := range payload {
		if val != nil {
			values.Set(key, gconv.String(val))
		}
	}
	return values, nil
}
// PullTaskResult polls an asynchronous task's query endpoint until the task
// reaches a terminal status or ctx is done.
//
// queryConfig keys read by this function:
//   - "task_id": gjson path used to extract the task ID from body (required);
//   - "url": query URL; a {id} placeholder is replaced with the task ID;
//   - "method": HTTP method, case-insensitive, defaults to GET. For POST a
//     JSON body {"task_id": <id>} is sent;
//   - "status_path": gjson path of the status in the response, defaults to
//     "status";
//   - "status_values": map whose "succeeded" and "failed" entries are compared
//     with the status via matchStatus;
//   - "interval_seconds": pause between attempts, defaults to 2.
//
// headMsg entries are applied as request headers on every attempt.
//
// It returns the decoded response and nil when the status matches
// "succeeded", the decoded response and an error when it matches "failed",
// and ctx.Err() when ctx is cancelled. Transport errors, non-2xx responses
// and unreadable or undecodable bodies are logged and retried. There is no
// attempt limit: callers bound the polling through ctx.
func PullTaskResult(ctx context.Context, body map[string]any, queryConfig map[string]any, headMsg map[string]any) (map[string]any, error) {
	// 1) Resolve configuration.
	// 1.1 Extract the task ID.
	taskIDPath := gconv.String(queryConfig["task_id"])
	taskID := gconv.String(gjson.New(body).Get(taskIDPath).Val())
	if taskID == "" {
		return nil, fmt.Errorf("无法从路径 %s 提取 taskID", taskIDPath)
	}
	g.Log().Infof(ctx, "[PullTaskResult] taskID=%s", taskID)

	// 1.2 Request URL with {id} substituted.
	queryURL := replaceURLParams(gconv.String(queryConfig["url"]), map[string]any{"id": taskID})

	// 1.3 Request method, normalised so that "post" also sends a body.
	method := strings.ToUpper(strings.TrimSpace(gconv.String(queryConfig["method"])))
	if method == "" {
		method = http.MethodGet
	}

	// 1.4 Status detection settings.
	statusPath := gconv.String(queryConfig["status_path"])
	if statusPath == "" {
		statusPath = "status"
	}
	statusValues, _ := queryConfig["status_values"].(map[string]any)

	// 1.5 Polling interval.
	interval := gconv.Int(queryConfig["interval_seconds"])
	if interval <= 0 {
		interval = 2
	}
	pause := time.Duration(interval) * time.Second

	// 1.6 Loop-invariant request parts: body bytes, headers and client.
	var payload []byte
	if method == http.MethodPost {
		bs, err := json.Marshal(map[string]any{"task_id": taskID})
		if err != nil {
			return nil, fmt.Errorf("序列化请求体失败: %w", err)
		}
		payload = bs
	}
	headers := ParseHeadMsgHeaders(headMsg)
	client := &http.Client{Timeout: 30 * time.Second}

	// 2) Poll until a terminal status is seen or ctx is done.
	for first := true; ; first = false {
		if !first {
			// Wait between attempts, waking up early on cancellation.
			if err := pullTaskSleep(ctx, pause); err != nil {
				return nil, err
			}
		}
		if err := ctx.Err(); err != nil {
			return nil, err
		}

		// Keep reqBody a nil interface when there is no payload.
		var reqBody io.Reader
		if payload != nil {
			reqBody = bytes.NewReader(payload)
		}
		req, err := http.NewRequestWithContext(ctx, method, queryURL, reqBody)
		if err != nil {
			return nil, fmt.Errorf("创建请求失败: %w", err)
		}
		for hk, hv := range headers {
			req.Header.Set(hk, hv)
		}

		resp, err := client.Do(req)
		if err != nil {
			g.Log().Warningf(ctx, "[PullTaskResult] 请求失败 taskID=%s err=%v", taskID, err)
			continue
		}
		raw, readErr := io.ReadAll(resp.Body)
		_ = resp.Body.Close() // the body has been consumed; a close error is not actionable
		g.Log().Infof(ctx, "[PullTaskResult] taskID=%s statusCode=%d body=%s", taskID, resp.StatusCode, string(raw))
		if readErr != nil {
			g.Log().Warningf(ctx, "[PullTaskResult] 读取响应失败 taskID=%s err=%v", taskID, readErr)
			continue
		}
		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			continue
		}

		var result map[string]any
		if err := json.Unmarshal(raw, &result); err != nil {
			// An undecodable body carries no status; retry instead of
			// evaluating a nil result.
			g.Log().Warningf(ctx, "[PullTaskResult] 解析响应失败 taskID=%s err=%v", taskID, err)
			continue
		}

		statusVal := gjson.New(result).Get(statusPath).Val()
		statusStr := gconv.String(statusVal)
		g.Log().Infof(ctx, "[PullTaskResult] 状态 taskID=%s status=%v", taskID, statusVal)
		if matchStatus(statusStr, statusValues["succeeded"]) {
			g.Log().Infof(ctx, "[PullTaskResult] 任务成功 taskID=%s", taskID)
			return result, nil
		}
		if matchStatus(statusStr, statusValues["failed"]) {
			g.Log().Errorf(ctx, "[PullTaskResult] 任务失败 taskID=%s", taskID)
			return result, fmt.Errorf("任务失败")
		}
	}
}

// pullTaskSleep blocks for d, or returns ctx.Err() as soon as ctx is done.
func pullTaskSleep(ctx context.Context, d time.Duration) error {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}
// matchStatus reports whether actual equals the expected status.
//
// expected may be a single value (compared through gconv.String) or a list
// ([]any or []string) of acceptable values. A nil expected, i.e. a status
// that was never configured, matches nothing; previously it converted to ""
// and therefore matched an empty actual status. Nil list items are ignored
// for the same reason.
func matchStatus(actual string, expected any) bool {
	if expected == nil {
		return false
	}
	if actual == gconv.String(expected) {
		return true
	}
	switch candidates := expected.(type) {
	case []any:
		for _, item := range candidates {
			if item != nil && actual == gconv.String(item) {
				return true
			}
		}
	case []string:
		for _, item := range candidates {
			if actual == item {
				return true
			}
		}
	}
	return false
}
// urlParamPattern matches a {key} placeholder inside a URL template. It is
// compiled once at package scope rather than on every call.
var urlParamPattern = regexp.MustCompile(`\{([^}]+)}`)

// replaceURLParams substitutes every {key} placeholder in url with the string
// form (gconv.String) of params[key]. Placeholders whose key is not present
// in params are left untouched.
func replaceURLParams(url string, params map[string]any) string {
	return urlParamPattern.ReplaceAllStringFunc(url, func(placeholder string) string {
		// Strip the surrounding braces to obtain the lookup key.
		key := strings.Trim(placeholder, "{}")
		if val, ok := params[key]; ok {
			return gconv.String(val)
		}
		return placeholder
	})
}
// InjectCallbackURL stores the callback address in the request payload.
//
// Despite its name, callbackURL is the payload key under which the address is
// written; the address itself is obtained from
// utils.GetCallbackURL(ctx, "/task/modelCallback"). When callbackURL is empty
// the payload is returned untouched. A nil payload is replaced by a fresh map
// so the write cannot panic. The given map is modified in place and returned.
func InjectCallbackURL(ctx context.Context, payload map[string]any, callbackURL string) map[string]any {
	if callbackURL == "" {
		return payload
	}
	if payload == nil {
		// Writing to a nil map panics; allocate one for the single entry.
		payload = make(map[string]any, 1)
	}
	payload[callbackURL] = utils.GetCallbackURL(ctx, "/task/modelCallback")
	return payload
}