feat(model): 添加流式配置支持并优化响应处理
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"model-gateway/model/entity"
|
||||
"net/url"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
tgjson "github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
|
||||
@@ -67,27 +69,40 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||||
}
|
||||
|
||||
// MapResponsePayload 映射模型响应为标准格式
|
||||
func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) {
|
||||
func MapResponsePayload(mapping map[string]any, result map[string]any) (map[string]any, error) {
|
||||
if len(mapping) == 0 {
|
||||
return responseBytes, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
responseJson := gjson.New(responseBytes)
|
||||
resultJson := gjson.New("{}")
|
||||
// 把 result 转成 JSON 字符串,tidwall/gjson 需要字符串输入
|
||||
resultBytes, _ := json.Marshal(result)
|
||||
resultStr := string(resultBytes)
|
||||
|
||||
mapped := make(map[string]any)
|
||||
|
||||
for standardField, modelPath := range mapping {
|
||||
path := gconv.String(modelPath)
|
||||
if path == "" {
|
||||
continue
|
||||
}
|
||||
val := responseJson.Get(path)
|
||||
if val.IsNil() {
|
||||
|
||||
value := tgjson.Get(resultStr, path)
|
||||
if !value.Exists() {
|
||||
continue
|
||||
}
|
||||
resultJson.Set(standardField, val.Val())
|
||||
// 如果是数组路径(含 #),取 Array;否则取单值
|
||||
if strings.Contains(path, "#") {
|
||||
var arr []any
|
||||
for _, v := range value.Array() {
|
||||
arr = append(arr, v.Value())
|
||||
}
|
||||
mapped[standardField] = arr
|
||||
} else {
|
||||
mapped[standardField] = value.Value()
|
||||
}
|
||||
}
|
||||
|
||||
return []byte(resultJson.String()), nil
|
||||
return mapped, nil
|
||||
}
|
||||
|
||||
// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
||||
|
||||
150
common/util/streaming.go
Normal file
150
common/util/streaming.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
)
|
||||
|
||||
// ================================================================
|
||||
|
||||
// ParseStreamResponse 流式响应解析(通用入口)
|
||||
func ParseStreamResponse(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) {
|
||||
enabled, _ := streamConfig["enabled"].(bool)
|
||||
if !enabled {
|
||||
return gjson.New(string(rawBytes)).Map(), nil
|
||||
}
|
||||
|
||||
parser, _ := streamConfig["parser"].(string)
|
||||
if parser == "base64_concat" {
|
||||
return parseBase64Stream(rawBytes)
|
||||
}
|
||||
|
||||
return parseSSEStream(rawBytes, streamConfig)
|
||||
}
|
||||
|
||||
// parseBase64Stream reassembles base64 payloads spread over a line-delimited
// JSON stream and decodes them to binary (used for TTS / audio models).
//
// Each non-empty line of rawBytes is parsed as a JSON object; lines that are
// not valid JSON are skipped. The non-empty string under the "data" key of
// every object is collected in arrival order, with spaces, tabs, CR and LF
// removed.
//
// Decoding is attempted in two stages:
//  1. All chunks are concatenated and decoded as one base64 string (padded
//     first, then unpadded). This covers a single encoding cut at arbitrary
//     positions.
//  2. If that fails, every chunk is decoded on its own and the resulting bytes
//     are concatenated. This covers streams in which each chunk was encoded
//     (and padded) independently, where the padding would otherwise end up in
//     the middle of the concatenated string.
//
// It returns map{"audio": []byte}. An error is returned only when both stages
// fail; it wraps the error from stage 1.
func parseBase64Stream(rawBytes []byte) (map[string]any, error) {
	var chunks []string

	for _, line := range strings.Split(string(rawBytes), "\n") {
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}

		var chunk map[string]any
		if err := json.Unmarshal([]byte(line), &chunk); err != nil {
			// Best effort: ignore anything that is not a JSON object line.
			continue
		}

		if data, ok := chunk["data"].(string); ok && data != "" {
			chunks = append(chunks, stripBase64Whitespace(data))
		}
	}

	// Stage 1: treat the stream as one base64 string split across chunks.
	audioBytes, err := decodeBase64Lenient(strings.Join(chunks, ""))
	if err == nil {
		return map[string]any{"audio": audioBytes}, nil
	}

	// Stage 2: treat every chunk as an independently encoded segment.
	var joined []byte
	for _, c := range chunks {
		part, partErr := decodeBase64Lenient(c)
		if partErr != nil {
			// Report the whole-stream failure; it is the primary decode path.
			return nil, fmt.Errorf("base64 解码失败: %w", err)
		}
		joined = append(joined, part...)
	}

	return map[string]any{"audio": joined}, nil
}

// stripBase64Whitespace removes spaces, tabs, CR and LF from s.
func stripBase64Whitespace(s string) string {
	return strings.Map(func(r rune) rune {
		if r == ' ' || r == '\n' || r == '\r' || r == '\t' {
			return -1
		}
		return r
	}, s)
}

// decodeBase64Lenient decodes s as padded standard base64 and, if that fails,
// as unpadded standard base64. The returned error is the one from the
// unpadded attempt.
func decodeBase64Lenient(s string) ([]byte, error) {
	if out, err := base64.StdEncoding.DecodeString(s); err == nil {
		return out, nil
	}
	return base64.RawStdEncoding.DecodeString(s)
}
|
||||
|
||||
// parseSSEStream SSE 流式解析(图片模型等)
|
||||
func parseSSEStream(rawBytes []byte, streamConfig map[string]any) (map[string]any, error) {
|
||||
events, _ := streamConfig["events"].([]any)
|
||||
if len(events) == 0 {
|
||||
return gjson.New(string(rawBytes)).Map(), nil
|
||||
}
|
||||
|
||||
lines := strings.Split(string(rawBytes), "\n")
|
||||
result := make(map[string]any)
|
||||
var partials []map[string]any
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || line == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "event:") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "data:") {
|
||||
line = strings.TrimPrefix(line, "data:")
|
||||
line = strings.TrimSpace(line)
|
||||
}
|
||||
|
||||
var chunk map[string]any
|
||||
if err := json.Unmarshal([]byte(line), &chunk); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
chunkType, _ := chunk["type"].(string)
|
||||
|
||||
for _, evt := range events {
|
||||
e, _ := evt.(map[string]any)
|
||||
match, _ := e["match"].(string)
|
||||
if !strings.Contains(chunkType, match) {
|
||||
continue
|
||||
}
|
||||
|
||||
fields, _ := e["fields"].(map[string]any)
|
||||
aggregateTo, _ := e["aggregate_to"].(string)
|
||||
evtType, _ := e["type"].(string)
|
||||
|
||||
switch evtType {
|
||||
case "partial":
|
||||
item := make(map[string]any)
|
||||
for localKey, chunkKey := range fields {
|
||||
item[localKey] = chunk[chunkKey.(string)]
|
||||
}
|
||||
partials = append(partials, item)
|
||||
|
||||
case "final":
|
||||
for localKey, chunkKey := range fields {
|
||||
val := gjson.New(chunk).Get(chunkKey.(string))
|
||||
if !val.IsNil() {
|
||||
if _, exists := result[aggregateTo]; !exists {
|
||||
result[aggregateTo] = make(map[string]any)
|
||||
}
|
||||
result[aggregateTo].(map[string]any)[localKey] = val.Val()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(partials) > 0 {
|
||||
for _, evt := range events {
|
||||
e, _ := evt.(map[string]any)
|
||||
if e["type"] == "partial" {
|
||||
if orderBy, ok := e["order_by"].(string); ok {
|
||||
sort.Slice(partials, func(i, j int) bool {
|
||||
return fmt.Sprint(partials[i][orderBy]) < fmt.Sprint(partials[j][orderBy])
|
||||
})
|
||||
}
|
||||
result[e["aggregate_to"].(string)] = partials
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mergedBytes, _ := json.Marshal(result)
|
||||
return gjson.New(mergedBytes).Map(), nil
|
||||
}
|
||||
@@ -32,6 +32,7 @@ type asynchModelCol struct {
|
||||
TokenConfig string
|
||||
ExtendMapping string
|
||||
QueryConfig string
|
||||
StreamConfig string
|
||||
}
|
||||
|
||||
var AsynchModelCol = asynchModelCol{
|
||||
@@ -64,6 +65,7 @@ var AsynchModelCol = asynchModelCol{
|
||||
TokenConfig: "token_config",
|
||||
ExtendMapping: "extend_mapping",
|
||||
QueryConfig: "query_config",
|
||||
StreamConfig: "stream_config",
|
||||
}
|
||||
|
||||
// AsynchModel 异步模型配置
|
||||
@@ -97,4 +99,5 @@ type AsynchModel struct {
|
||||
TokenConfig map[string]any `orm:"token_config" json:"tokenConfig"`
|
||||
ExtendMapping map[string]any `orm:"extend_mapping" json:"extendMapping"`
|
||||
QueryConfig map[string]any `orm:"query_config" json:"queryConfig"`
|
||||
StreamConfig map[string]any `orm:"stream_config" json:"streamConfig"`
|
||||
}
|
||||
|
||||
@@ -124,14 +124,61 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
|
||||
return
|
||||
}
|
||||
|
||||
// 4) 调用模型(不重试,失败直接回调)
|
||||
textResult, err := w.callModel(ctx, t, model, payload)
|
||||
// 4) 调用模型
|
||||
var textResult map[string]any
|
||||
if streamEnabled, _ := model.StreamConfig["enabled"].(bool); streamEnabled {
|
||||
rawBytes, modelErr := w.callModelRaw(ctx, t, model, payload)
|
||||
if modelErr != nil {
|
||||
w.failTask(ctx, t, modelErr.Error())
|
||||
return
|
||||
}
|
||||
textResult, err = util.ParseStreamResponse(rawBytes, model.StreamConfig)
|
||||
if err != nil {
|
||||
w.failTask(ctx, t, err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
textResult, err = w.callModel(ctx, t, model, payload)
|
||||
if err != nil {
|
||||
w.failTask(ctx, t, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 5) 模型返回映射处理
|
||||
textResult, err = util.MapResponsePayload(model.ResponseMapping, textResult)
|
||||
if err != nil {
|
||||
w.failTask(ctx, t, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 5) 上传 OSS(可重试)
|
||||
// 6) 保存临时文件(区分二进制音频和JSON文本)
|
||||
if audioData, ok := textResult["audio"].([]byte); ok {
|
||||
tmpPath, tmpErr := saveTmpResult(t.TaskID, audioData, ".mp3")
|
||||
if tmpErr == nil && tmpPath != "" {
|
||||
if t.TmpFile != "" {
|
||||
_ = os.Remove(t.TmpFile)
|
||||
}
|
||||
t.TmpFile = tmpPath
|
||||
t.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
|
||||
}
|
||||
} else {
|
||||
mappedBytes, _ := json.Marshal(textResult)
|
||||
if len(mappedBytes) > 0 {
|
||||
tmpPath, tmpErr := saveTmpResult(t.TaskID, mappedBytes, ".json")
|
||||
if tmpErr == nil && tmpPath != "" {
|
||||
if t.TmpFile != "" {
|
||||
_ = os.Remove(t.TmpFile)
|
||||
}
|
||||
t.TmpFile = tmpPath
|
||||
t.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 7) 上传 OSS(可重试)
|
||||
var oss *gateway.UploadFileResponse
|
||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||
if attempt > 0 {
|
||||
@@ -150,35 +197,35 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
|
||||
}
|
||||
}
|
||||
|
||||
// 6) 解析校验(可重试,失败重新调模型)
|
||||
//if req.BuildType == 1 {
|
||||
// for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||
// if attempt > 0 {
|
||||
// g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
|
||||
// }
|
||||
// // 6.1) 校验数据
|
||||
// err = util.ValidatePromptResult(textResult, model)
|
||||
// if err == nil {
|
||||
// break
|
||||
// }
|
||||
// g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
|
||||
// t.TaskID, attempt, maxRetry, err)
|
||||
// if attempt == maxRetry {
|
||||
// w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
|
||||
// return
|
||||
// }
|
||||
// // 6.2) 重新调模型
|
||||
// newResult, modelErr := w.callModel(ctx, t, model, payload)
|
||||
// if modelErr != nil {
|
||||
// g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
|
||||
// t.TaskID, attempt, maxRetry, modelErr)
|
||||
// continue
|
||||
// }
|
||||
// textResult = newResult
|
||||
// }
|
||||
//}
|
||||
//8) 解析校验(可重试,失败重新调模型)
|
||||
if req.BuildType == 1 {
|
||||
for attempt := 0; attempt <= maxRetry; attempt++ {
|
||||
if attempt > 0 {
|
||||
g.Log().Infof(ctx, "[执行任务][重试] JSON解析 第%d/%d次 taskId=%s", attempt, maxRetry, t.TaskID)
|
||||
}
|
||||
// 6.1) 校验数据
|
||||
err = util.ValidatePromptResult(textResult, model)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
g.Log().Warningf(ctx, "[执行任务][解析失败] taskId=%s attempt=%d/%d err=%v",
|
||||
t.TaskID, attempt, maxRetry, err)
|
||||
if attempt == maxRetry {
|
||||
w.failTask(ctx, t, fmt.Sprintf("JSON解析重试耗尽: %v", err))
|
||||
return
|
||||
}
|
||||
// 6.2) 重新调模型
|
||||
newResult, modelErr := w.callModel(ctx, t, model, payload)
|
||||
if modelErr != nil {
|
||||
g.Log().Warningf(ctx, "[执行任务][重试] 重新调模型失败 taskId=%s attempt=%d/%d err=%v",
|
||||
t.TaskID, attempt, maxRetry, modelErr)
|
||||
continue
|
||||
}
|
||||
textResult = newResult
|
||||
}
|
||||
}
|
||||
|
||||
// 7) 成功回调
|
||||
// 9) 成功回调
|
||||
t.State = 2
|
||||
t.OssFile = oss.FileAddressPrefix + oss.FileURL
|
||||
t.FileType = oss.FileFormat
|
||||
@@ -199,9 +246,40 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask, req *
|
||||
|
||||
g.Log().Infof(ctx, "[执行任务][成功] taskId=%s fileType=%s textLen=%d callbackUrl=%s",
|
||||
t.TaskID, oss.FileFormat, len(textResult), t.CallbackURL)
|
||||
|
||||
// 10) 删除临时文件
|
||||
_ = os.Remove(t.TmpFile)
|
||||
}
|
||||
|
||||
// callModelRaw 调用模型,返回原始字节(不做响应映射,用于流式输出)
|
||||
func (w *asyncWorker) callModelRaw(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) ([]byte, error) {
|
||||
var data []byte
|
||||
var err error
|
||||
|
||||
if task.Phase == 1 && strings.TrimSpace(task.TmpFile) != "" {
|
||||
data, err = os.ReadFile(task.TmpFile)
|
||||
if err != nil || len(data) == 0 {
|
||||
data = nil
|
||||
}
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
_ = dao.Stat.IncRequestCount(ctx, time.Now(), int64(task.TenantId), task.Creator, task.ModelName)
|
||||
data, err = InvokeModel(ctx, model, payload, task.ModelKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmpPath, tmpErr := saveTmpResult(task.TaskID, data, "")
|
||||
if tmpErr == nil && tmpPath != "" {
|
||||
task.TmpFile = tmpPath
|
||||
task.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, task.Id, tmpPath)
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// 返回: ossURL(成功时有值), fileType, textResult(失败时是错误信息), retryable(是否可重试)
|
||||
// callModel 调用模型 + 检测文件类型 + 保存临时文件
|
||||
func (w *asyncWorker) callModel(ctx context.Context, task *entity.AsynchTask, model *entity.AsynchModel, payload map[string]any) (map[string]any, error) {
|
||||
@@ -302,14 +380,7 @@ func InvokeModel(ctx context.Context, model *entity.AsynchModel, payload map[str
|
||||
msg := string(b)
|
||||
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
||||
}
|
||||
|
||||
// 8)响应参数映射
|
||||
mappedResponse, err := util.MapResponsePayload(model.ResponseMapping, b)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
||||
return b, nil
|
||||
}
|
||||
return mappedResponse, nil
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// // InvokeModel 调用模型服务,返回二进制结果
|
||||
|
||||
11
update.sql
11
update.sql
@@ -42,15 +42,8 @@ CREATE TABLE IF NOT EXISTS asynch_models (
|
||||
remark TEXT DEFAULT '' -- 备注
|
||||
response_token_field VARCHAR(128) NOT NULL DEFAULT ''; -- 响应中消耗token的字段映射
|
||||
operator_name VARCHAR(64) NOT NULL DEFAULT '', -- 运营商名称
|
||||
token_config JSONB NOT NULL DEFAULT '{
|
||||
"zh_ratio": 1.0,
|
||||
"en_ratio": 1.3,
|
||||
"space_ratio": 0.1,
|
||||
"punctuation_ratio": 0.1,
|
||||
"max_window_size": 8192,
|
||||
"reserve_ratio": 0.2,
|
||||
"min_reserve": 512,
|
||||
}'::jsonb -- Token配置
|
||||
stream_config JSONB NOT NULL DEFAULT '{}'::jsonb, -- 流式配置
|
||||
token_config JSONB NOT NULL DEFAULT '{}'::jsonb -- Token配置
|
||||
extend_mapping JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
query_config JSONB NOT NULL DEFAULT '{}'::jsonb;
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user